adding jupyter, updating extract features adding simple test file

This commit is contained in:
thomwolf 2018-11-02 14:25:21 +01:00
parent 844b2f0e6f
commit c9690e57f8
3 changed files with 318 additions and 29 deletions

View File

@ -0,0 +1,288 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TensorFlow code"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-02T13:05:56.692585Z",
"start_time": "2018-11-02T13:05:55.699169Z"
}
},
"outputs": [],
"source": [
"from extract_features import *"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-02T13:18:23.944585Z",
"start_time": "2018-11-02T13:18:23.821309Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Example ***\n",
"INFO:tensorflow:unique_id: 0\n",
"INFO:tensorflow:tokens: [CLS] who was jim henson ? [SEP] jim henson was a puppet ##eer [SEP]\n",
"INFO:tensorflow:input_ids: 101 2040 2001 3958 27227 1029 102 3958 27227 2001 1037 13997 11510 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:input_type_ids: 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n"
]
}
],
"source": [
"data_dir=\"/Users/thomaswolf/Documents/Thomas/Code/HF/BERT/data/glue_data/MRPC/\"\n",
"vocab_file=\"/Users/thomaswolf/Documents/Thomas/Code/HF/BERT/google_models/uncased_L-12_H-768_A-12/vocab.txt\"\n",
"bert_config_file=\"/Users/thomaswolf/Documents/Thomas/Code/HF/BERT/google_models/uncased_L-12_H-768_A-12/bert_config.json\"\n",
"init_checkpoint=\"/Users/thomaswolf/Documents/Thomas/Code/HF/BERT/google_models/uncased_L-12_H-768_A-12/bert_model.ckpt\"\n",
"max_seq_length=128\n",
"input_file=\"/Users/thomaswolf/Documents/Thomas/Code/HF/BERT/pytorch-pretrained-BERT/input.txt\"\n",
"\n",
"layer_indexes = [-1]\n",
"bert_config = modeling.BertConfig.from_json_file(bert_config_file)\n",
"tokenizer = tokenization.FullTokenizer(\n",
" vocab_file=vocab_file, do_lower_case=True)\n",
"examples = read_examples(input_file)\n",
"\n",
"features = convert_examples_to_features(\n",
" examples=examples, seq_length=max_seq_length, tokenizer=tokenizer)\n",
"unique_id_to_feature = {}\n",
"for feature in features:\n",
" unique_id_to_feature[feature.unique_id] = feature"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-02T13:18:24.802620Z",
"start_time": "2018-11-02T13:18:24.764474Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x128feb7b8>) includes params argument, but params are not passed to Estimator.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpp9hntmfs\n",
"INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpp9hntmfs', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x1263809e8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n",
"WARNING:tensorflow:Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.\n",
"INFO:tensorflow:_TPUContext: eval_on_tpu True\n",
"WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n"
]
}
],
"source": [
"is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2\n",
"run_config = tf.contrib.tpu.RunConfig(\n",
" master=None,\n",
" tpu_config=tf.contrib.tpu.TPUConfig(\n",
" num_shards=1,\n",
" per_host_input_for_training=is_per_host))\n",
"\n",
"model_fn = model_fn_builder(\n",
" bert_config=bert_config,\n",
" init_checkpoint=init_checkpoint,\n",
" layer_indexes=layer_indexes,\n",
" use_tpu=False,\n",
" use_one_hot_embeddings=False)\n",
"\n",
"# If TPU is not available, this will fall back to normal Estimator on CPU\n",
"# or GPU.\n",
"estimator = tf.contrib.tpu.TPUEstimator(\n",
" use_tpu=False,\n",
" model_fn=model_fn,\n",
" config=run_config,\n",
" predict_batch_size=1)\n",
"\n",
"input_fn = input_fn_builder(\n",
" features=features, seq_length=max_seq_length)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-02T13:19:20.060587Z",
"start_time": "2018-11-02T13:19:14.804525Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpp9hntmfs, running initialization to predict.\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Running infer on CPU\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:prediction_loop marked as finished\n",
"INFO:tensorflow:prediction_loop marked as finished\n"
]
}
],
"source": [
"all_out = []\n",
"for result in estimator.predict(input_fn, yield_single_examples=True):\n",
" unique_id = int(result[\"unique_id\"])\n",
" feature = unique_id_to_feature[unique_id]\n",
" output_json = collections.OrderedDict()\n",
" output_json[\"linex_index\"] = unique_id\n",
" all_features = []\n",
" for (i, token) in enumerate(feature.tokens):\n",
" all_layers = []\n",
" for (j, layer_index) in enumerate(layer_indexes):\n",
" layer_output = result[\"layer_output_%d\" % j]\n",
" layers = collections.OrderedDict()\n",
" layers[\"index\"] = layer_index\n",
" layers[\"values\"] = [\n",
" round(float(x), 6) for x in layer_output[i:(i + 1)].flat\n",
" ]\n",
" all_layers.append(layers)\n",
" features = collections.OrderedDict()\n",
" features[\"token\"] = token\n",
" features[\"layers\"] = all_layers\n",
" all_features.append(features)\n",
" output_json[\"features\"] = all_features\n",
" all_out.append(output_json)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-02T13:22:39.694206Z",
"start_time": "2018-11-02T13:22:39.663432Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"14"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(all_out)\n",
"len(all_out[0])\n",
"all_out[0].keys()\n",
"len(all_out[0]['features'])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-02T13:23:05.752981Z",
"start_time": "2018-11-02T13:23:05.723891Z"
}
},
"outputs": [],
"source": [
"tensorflow_output = all_out[0]['features'][0]['layers'][0]['values']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTorch code"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-02T13:24:27.644785Z",
"start_time": "2018-11-02T13:24:27.611996Z"
}
},
"outputs": [],
"source": [
"from extract_features_pytorch import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python [conda env:bert]",
"language": "python",
"name": "conda-env-bert-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
},
"toc": {
"colors": {
"hover_highlight": "#DAA520",
"running_highlight": "#FF0000",
"selected_highlight": "#FFD700"
},
"moveMenuLeft": true,
"nav_menu": {
"height": "48px",
"width": "252px"
},
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 4,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -37,35 +37,6 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
level = logging.INFO)
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--input_file", default=None, type=str, required=True)
parser.add_argument("--vocab_file", default=None, type=str, required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_file", default=None, type=str, required=True)
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
help="The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
help="Initial checkpoint (usually from a pre-trained BERT model).")
## Other parameters
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--do_lower_case", default=True, action='store_true',
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
parser.add_argument("--local_rank",
type=int,
default=-1,
help = "local_rank for distributed training on gpus")
args = parser.parse_args()
class InputExample(object):
@ -219,6 +190,35 @@ def read_examples(input_file):
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--input_file", default=None, type=str, required=True)
parser.add_argument("--vocab_file", default=None, type=str, required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_file", default=None, type=str, required=True)
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
help="The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
help="Initial checkpoint (usually from a pre-trained BERT model).")
## Other parameters
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--do_lower_case", default=True, action='store_true',
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
parser.add_argument("--local_rank",
type=int,
default=-1,
help = "local_rank for distributed training on gpus")
args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()

1
input.txt Normal file
View File

@ -0,0 +1 @@
Who was Jim Henson ? ||| Jim Henson was a puppeteer