mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
added file to convert pytorch->tf
This commit is contained in:
parent
96c2b77f0f
commit
968c1b44cb
@ -21,26 +21,6 @@ import numpy as np
|
||||
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
|
||||
|
||||
|
||||
# def __get_var_names(config):
|
||||
#
|
||||
# models = {
|
||||
# 'BertModel': BertModel(config),
|
||||
# 'BertForMaskedLM': BertForMaskedLM(config),
|
||||
# 'BertForPreTraining': BertForPreTraining(config),
|
||||
# 'BertForMultipleChoice': BertForMultipleChoice(config, num_choices=100),
|
||||
# 'BertForNextSentencePrediction': BertForNextSentencePrediction(config),
|
||||
# 'BertForSequenceClassification': BertForSequenceClassification(config, num_labels=100),
|
||||
# 'BertForQuestionAnswering': BertForQuestionAnswering(config)
|
||||
# }
|
||||
#
|
||||
# for name, model in models.items():
|
||||
# state_dict = model.state_dict()
|
||||
# torch_vars = []
|
||||
# for var_ in state_dict:
|
||||
# torch_vars.append(var_ + ', ' + str(tuple(state_dict[var_].shape)))
|
||||
# json.dump(torch_vars, fp=open('torch_var_names_{}.json'.format(name), 'w'), indent=3)
|
||||
|
||||
|
||||
|
||||
def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
|
||||
|
||||
@ -58,8 +38,7 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
|
||||
N BertForQuestionAnswering
|
||||
|
||||
Note:
|
||||
TF isn't & shouldn't be a package-level requirement; this
|
||||
feature is requested enough to warrant a local import.
|
||||
To keep TF out of package-level requirements, tf is imported locally.
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
|
Loading…
Reference in New Issue
Block a user