# PyTorch to TensorFlow Conversion Test Notebook

To run this notebook follow these steps, modifying the **Config** section as necessary:

1. Point `pt_model_dir` to your local directory containing the PyTorch BERT model to be converted.
2. Point `tf_bert_dir` to your clone of Google's BERT implementation, which can be found here: https://github.com/google-research/bert.

Note:
1. This feature currently only supports the base BERT models (uncased/cased).
2. The TensorFlow model will be written to `tf_model_dir`.

## Config

In [1]:
import os
import sys

# --- Model / tokenizer selection ---------------------------------------------
model_cls  = 'BertModel'
model_typ  = 'bert-base-uncased'
token_cls  = 'BertTokenizer'

# --- Tokenization settings ---------------------------------------------------
max_seq    = 12
CLS        = "[CLS]"
SEP        = "[SEP]"
MASK       = "[MASK]"
CLS_IDX    = 0
layer_idxs = tuple(range(12))
input_text = "jim henson was a puppeteer"

# --- Root directories (edit these for your machine) --------------------------
pt_model_dir = "/home/ubuntu/.pytorch-pretrained-BERT-cache/{}".format(model_typ)
tf_bert_dir  = "/home/ubuntu/bert"
tf_model_dir = os.path.join(pt_model_dir, 'tf')

# Checkpoint files are named after the model type with dashes replaced by
# underscores; compute that stem once and reuse it for both frameworks.
_ckpt_stem = model_typ.replace("-", "_")

# --- PyTorch artefacts -------------------------------------------------------
pt_vocab_file  = os.path.join(pt_model_dir, "vocab.txt")
pt_init_ckpt   = os.path.join(pt_model_dir, _ckpt_stem + ".bin")

# --- TensorFlow artefacts ----------------------------------------------------
tf_vocab_file  = os.path.join(tf_model_dir, "vocab.txt")
tf_init_ckpt   = os.path.join(tf_model_dir, _ckpt_stem + ".ckpt")
tf_config_file = os.path.join(tf_model_dir, "bert_config.json")

# exist_ok=True already makes this a no-op when the directory is present.
os.makedirs(tf_model_dir, exist_ok=True)

### Tokenization

In [2]:
def tokenize(text, tokenizer):
    """Turn raw text into fixed-length model inputs.

    The text is stripped and lower-cased, split with ``tokenizer.tokenize``,
    truncated to leave room for the two special tokens, wrapped in ``CLS`` /
    ``SEP``, mapped to ids with ``tokenizer.convert_tokens_to_ids`` and finally
    zero-padded up to ``max_seq`` (a notebook-level setting).

    Args:
        text: The input string.
        tokenizer: Any object exposing ``tokenize`` and
            ``convert_tokens_to_ids``.

    Returns:
        A tuple ``(input_ids, mask_ids, seg_ids)`` of lists: the token ids, a
        mask holding 1 for real tokens and 0 for padding, and segment ids that
        are all 0.
    """
    normalized = text.strip().lower()

    # Reserve two slots for the CLS and SEP markers.
    tokens = tokenizer.tokenize(normalized)[:max_seq - 2]
    tokens.insert(CLS_IDX, CLS)
    tokens.append(SEP)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    n_real = len(input_ids)
    pad = [0] * (max_seq - n_real)

    mask_ids = [1] * n_real + pad
    seg_ids = [0] * n_real + pad
    input_ids += pad
    return input_ids, mask_ids, seg_ids

## Pytorch execution

In [3]:
import numpy as np
import torch
from pytorch_pretrained_bert import (BertConfig,
                                     BertModel, 
                                     BertTokenizer, 
                                     BertForSequenceClassification)

# Save Vocab
# The vocabulary is written to both model directories so the TensorFlow
# tokenizer later reads exactly the same file.
pt_tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_typ,
    cache_dir=pt_model_dir)
pt_tokenizer.save_vocabulary(pt_model_dir)
pt_tokenizer.save_vocabulary(tf_model_dir)

# Save Model
pt_model = BertModel.from_pretrained(
    pretrained_model_name_or_path=model_typ,
    cache_dir=pt_model_dir).to('cpu')
pt_model.eval()
# Dropout is zeroed in the config before it is written out, so the JSON file
# consumed on the TensorFlow side carries the same (disabled) dropout values.
pt_model.config.hidden_dropout_prob = 0.0
pt_model.config.attention_probs_dropout_prob = 0.0
pt_model.config.to_json_file(tf_config_file)
torch.save(pt_model.state_dict(), pt_init_ckpt)

# Inputs
input_ids_pt, mask_ids_pt, seg_ids_pt = tokenize(input_text, pt_tokenizer)

# PT Embedding
# unsqueeze(0) adds the leading batch dimension (batch size 1).
tok_tensor = torch.tensor(input_ids_pt).to('cpu').unsqueeze(0)
seg_tensor = torch.tensor(seg_ids_pt).to('cpu').unsqueeze(0)
msk_tensor = torch.tensor(mask_ids_pt).to('cpu').unsqueeze(0)

# Inference only: no_grad() avoids building an autograd graph for the forward
# pass. The second return value is what gets compared against the TensorFlow
# model's get_pooled_output() further down, hence the name.
with torch.no_grad():
    encoded_layers, pooled_output = pt_model(tok_tensor, seg_tensor, msk_tensor)
pt_embedding = pooled_output.numpy()
print("Pytorch embedding shape: {}".format(pt_embedding.shape))

100%|██████████| 231508/231508 [00:00<00:00, 41092464.26B/s]
100%|██████████| 407873900/407873900 [00:07<00:00, 58092479.52B/s]


Pytorch embedding shape: (1, 768)


## Pytorch &rarr; Tensorflow conversion

In [4]:
from pytorch_pretrained_bert.convert_pytorch_checkpoint_to_tf import main

# Command-line style arguments for the checkpoint converter, spelled out as
# flag/value pairs so each setting is easy to spot and tweak.
conversion_args = [
    '--model_name', model_typ,
    '--pytorch_model_path', pt_init_ckpt,
    '--tf_cache_dir', tf_model_dir,
    '--cache_dir', pt_model_dir,
]
main(conversion_args)

Instructions for updating:
Colocations handled automatically by placer.
bert/embeddings/word_embeddings                             initialized
bert/embeddings/position_embeddings                         initialized
bert/embeddings/token_type_embeddings                       initialized
bert/embeddings/LayerNorm/gamma                             initialized
bert/embeddings/LayerNorm/beta                              initialized
bert/encoder/layer_0/attention/self/query/kernel            initialized
bert/encoder/layer_0/attention/self/query/bias              initialized
bert/encoder/layer_0/attention/self/key/kernel              initialized
bert/encoder/layer_0/attention/self/key/bias                initialized
bert/encoder/layer_0/attention/self/value/kernel            initialized
bert/encoder/layer_0/attention/self/value/bias              initialized
bert/encoder/layer_0/attention/output/dense/kernel          initialized
bert/encoder/layer_0/attention/output/dense/bias            init

## Tensorflow execution

In [5]:
import tensorflow as tf
sys.path.insert(0, tf_bert_dir)
import modeling
import tokenization

# Start from an empty default graph before building the model.
tf.reset_default_graph()

# Process text
tf_tokenizer = tokenization.FullTokenizer(vocab_file=tf_vocab_file)

# Graph inputs
input_ids_tf, mask_ids_tf, seg_ids_tf = tokenize(input_text, tf_tokenizer)
config = modeling.BertConfig.from_json_file(tf_config_file)

# All three inputs are int32 placeholders with a batch dimension of 1 and an
# unspecified second dimension; only the name differs.
placeholder_spec = dict(dtype=tf.int32, shape=[1, None])
input_tensor = tf.placeholder(name='input_ids', **placeholder_spec)
mask_tensor = tf.placeholder(name='mask_ids', **placeholder_spec)
seg_tensor = tf.placeholder(name='seg_ids', **placeholder_spec)

tf_model = modeling.BertModel(
    config=config,
    is_training=False,
    input_ids=input_tensor,
    input_mask=mask_tensor,
    token_type_ids=seg_tensor,
    use_one_hot_embeddings=False)
output_layer = tf_model.get_pooled_output()

# Load tf model
# The session is intentionally left open: the weight-comparison cell below
# reads the restored variables through it.
session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
vars_to_load = list(tf.global_variables())
session.run(tf.variables_initializer(var_list=vars_to_load))
saver = tf.train.Saver(vars_to_load)
saver.restore(session, save_path=tf_init_ckpt)

# TF Embedding
fetches = output_layer
feed_dict = {
    input_tensor: [input_ids_tf],
    mask_tensor: [mask_ids_tf],
    seg_tensor: [seg_ids_tf],
}
tf_embedding = session.run(fetches=fetches, feed_dict=feed_dict)
print("Tensorflow embedding shape: {}".format(tf_embedding.shape))


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from /home/ubuntu/.pytorch-pretrained-BERT-cache/bert-base-uncased/tf/bert_base_uncased.ckpt
Tensorflow embedding shape: (1, 768)


## Compare Tokenization

In [6]:
# Print the PyTorch and TensorFlow versions of each input side by side so any
# tokenization mismatch is visible at a glance.
tokenization_report = (
    ("TOKEN_IDS_PT: {}", input_ids_pt),
    ("TOKEN_IDS_TF: {}", input_ids_tf),
    ("SEG_IDS_PT:   {}", seg_ids_pt),
    ("SEG_IDS_TF:   {}", seg_ids_tf),
    ("MASK_IDS_PT:  {}", mask_ids_pt),
    ("MASK_IDS_TF:  {}", mask_ids_tf),
)
for template, ids in tokenization_report:
    print(template.format(ids))

TOKEN_IDS_PT: [101, 3958, 27227, 2001, 1037, 13997, 11510, 102, 0, 0, 0, 0]
TOKEN_IDS_TF: [101, 3958, 27227, 2001, 1037, 13997, 11510, 102, 0, 0, 0, 0]
SEG_IDS_PT:   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
SEG_IDS_TF:   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
MASK_IDS_PT:  [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
MASK_IDS_TF:  [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]


## Compare Model Weights

In [7]:
# Substrings of PyTorch parameter names whose tensors are transposed before
# being compared with their TensorFlow counterparts.
tensors_to_transopse = (
    "dense.weight",
    "attention.self.query",
    "attention.self.key",
    "attention.self.value"
)

# Ordered (pattern, replacement) rewrite rules that turn a PyTorch parameter
# name into a TensorFlow variable name. Order matters: the '.' -> '/' rule
# must run before the 'LayerNorm/...' rules, and the generic
# 'weight' -> 'kernel' rule must run last.
var_map = (
    ('layer.', 'layer_'),
    ('word_embeddings.weight', 'word_embeddings'),
    ('position_embeddings.weight', 'position_embeddings'),
    ('token_type_embeddings.weight', 'token_type_embeddings'),
    ('.', '/'),
    ('LayerNorm/weight', 'LayerNorm/gamma'),
    ('LayerNorm/bias', 'LayerNorm/beta'),
    ('weight', 'kernel')
)


def to_tf_var_name(name: str) -> str:
    """Translate a PyTorch parameter name into its TensorFlow variable name.

    Every rule in ``var_map`` is applied in sequence to the running result,
    and the outcome is prefixed with ``bert/``.

    Args:
        name: Parameter name as found in the PyTorch ``state_dict``.

    Returns:
        The corresponding TensorFlow variable name (without a ``:0`` suffix).
    """
    converted = name
    for rule in var_map:
        old, new = rule
        converted = converted.replace(old, new)
    return 'bert/' + converted

# Snapshot every TensorFlow variable as a numpy array, keyed by variable name.
tf_vars = {v.name: session.run(fetches=v) for v in tf.global_variables()}

# Build the PyTorch side under TensorFlow-style names, transposing the tensors
# whose names match one of the listed patterns.
pt_vars = {}
for pt_name, tensor in pt_model.state_dict().items():
    weights = tensor.detach().numpy()
    if any(patt in pt_name for patt in tensors_to_transopse):
        weights = weights.T
    pt_vars[to_tf_var_name(pt_name)] = weights

# NOTE: the loop variables are deliberately NOT called `pt` / `tf`; binding
# `tf` to an array would shadow the tensorflow module and break any
# TensorFlow cell that is (re-)run afterwards.
for var_name, tf_wts in tf_vars.items():

    # Drop the ":<output index>" suffix. str.strip(":0") would instead strip
    # any run of ':' and '0' characters from BOTH ends of the name.
    base_name = var_name.rsplit(":", 1)[0]
    pt_wts = pt_vars[base_name]

    print(base_name)

    # Assert equivalence
    print("|sum(pt_wts - tf_wts)| = {}".format(
        np.abs(np.sum(pt_wts - tf_wts, keepdims=False))
    ))
    # Element-wise check: a zero *sum* of differences can hide mismatches
    # that happen to cancel each other out.
    np.testing.assert_array_equal(
        pt_wts, tf_wts,
        err_msg="weights differ for {}".format(base_name))

    # Preview the first five values (of the first row for 2-D tensors).
    preview = (0, slice(5)) if pt_wts.ndim == 2 else slice(5)
    print("PT: shape: {0} values: {1}".format(pt_wts.shape, pt_wts[preview]))
    print("TF: shape: {0} values: {1}".format(tf_wts.shape, tf_wts[preview]))
    print()

bert/embeddings/word_embeddings
|sum(pt_wts - tf_wts)| = 0.0
PT: shape: (30522, 768) values: [-0.01018257 -0.06154883 -0.02649689 -0.0420608   0.00116716]
TF: shape: (30522, 768) values: [-0.01018257 -0.06154883 -0.02649689 -0.0420608   0.00116716]

bert/embeddings/token_type_embeddings
|sum(pt_wts - tf_wts)| = 0.0
PT: shape: (2, 768) values: [0.00043164 0.01098826 0.00370439 0.00150542 0.00057812]
TF: shape: (2, 768) values: [0.00043164 0.01098826 0.00370439 0.00150542 0.00057812]

bert/embeddings/position_embeddings
|sum(pt_wts - tf_wts)| = 0.0
PT: shape: (512, 768) values: [ 0.01750538 -0.02563101 -0.03664156 -0.02528613  0.00797095]
TF: shape: (512, 768) values: [ 0.01750538 -0.02563101 -0.03664156 -0.02528613  0.00797095]

bert/embeddings/LayerNorm/beta
|sum(pt_wts - tf_wts)| = 0.0
PT: shape: (768,) values: [-0.02591471 -0.0195513   0.02423946  0.08904593 -0.06281059]
TF: shape: (768,) values: [-0.02591471 -0.0195513   0.02423946  0.08904593 -0.06281059]

bert/embeddings/LayerNorm

## Compare Layer-12 Projections

In [8]:
# Mean squared error between the final pooled embedding produced by each
# framework, followed by the first five values of each for a quick eyeball
# comparison.
squared_diff = np.square(pt_embedding - tf_embedding)
MSE = np.mean(squared_diff)
print("MSE: {}".format(MSE))
for template, embedding in (("PT-values: {}", pt_embedding),
                            ("TF-values: {}", tf_embedding)):
    print(template.format(embedding[0, :5]))

MSE: 2.7155439966009e-05
PT-values: [-0.876663   -0.41088238 -0.12200808  0.44941     0.19445966]
TF-values: [-0.8742865  -0.40621698 -0.10585472  0.444904    0.1825743 ]
