huggingface/transformers — commit eda07efaa5 (parent f51161e230)

Add POS tagging and Phrase chunking token classification examples (#6457)

* Add more token classification examples
* POS tagging example
* Phrase chunking example
* PR review fixes
* Add conllu to third party list (used in token classification examples)
examples/requirements.txt (+1)

@@ -15,3 +15,4 @@ pandas
 nlp
 fire
 pytest
+conllu
examples/token-classification/run.sh (Normal file → Executable file, +1)

@@ -18,6 +18,7 @@ export SAVE_STEPS=750
 export SEED=1

 python3 run_ner.py \
+--task_type NER \
 --data_dir . \
 --labels ./labels.txt \
 --model_name_or_path $BERT_MODEL \
examples/token-classification/run_chunk.sh (new Executable file, +37)

@@ -0,0 +1,37 @@
+if ! [ -f ./dev.txt ]; then
+echo "Downloading CONLL2003 dev dataset...."
+curl -L -o ./dev.txt 'https://github.com/davidsbatista/NER-datasets/raw/master/CONLL2003/valid.txt'
+fi
+
+if ! [ -f ./test.txt ]; then
+echo "Downloading CONLL2003 test dataset...."
+curl -L -o ./test.txt 'https://github.com/davidsbatista/NER-datasets/raw/master/CONLL2003/test.txt'
+fi
+
+if ! [ -f ./train.txt ]; then
+echo "Downloading CONLL2003 train dataset...."
+curl -L -o ./train.txt 'https://github.com/davidsbatista/NER-datasets/raw/master/CONLL2003/train.txt'
+fi
+
+export MAX_LENGTH=200
+export BERT_MODEL=bert-base-uncased
+export OUTPUT_DIR=chunker-model
+export BATCH_SIZE=32
+export NUM_EPOCHS=3
+export SAVE_STEPS=750
+export SEED=1
+
+python3 run_ner.py \
+--task_type Chunk \
+--data_dir . \
+--model_name_or_path $BERT_MODEL \
+--output_dir $OUTPUT_DIR \
+--max_seq_length $MAX_LENGTH \
+--num_train_epochs $NUM_EPOCHS \
+--per_gpu_train_batch_size $BATCH_SIZE \
+--save_steps $SAVE_STEPS \
+--seed $SEED \
+--do_train \
+--do_eval \
+--do_predict
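The CoNLL-2003 scripts above (run.sh and run_chunk.sh) train on files where each line holds one token followed by its POS tag, syntactic chunk tag, and NER tag, separated by spaces. A minimal sketch of why the Chunk task reads the second-to-last column while NER reads the last one (the sample line is illustrative, not taken from this commit):

# Illustrative CoNLL-2003 line: token, POS tag, chunk tag, NER tag.
line = "U.N. NNP I-NP I-ORG"
splits = line.split(" ")
token = splits[0]         # "U.N."
ner_label = splits[-1]    # "I-ORG" -> the NER task uses label_idx=-1 (last column)
chunk_label = splits[-2]  # "I-NP"  -> the Chunk task uses label_idx=-2 (second-to-last column)
print(token, chunk_label, ner_label)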
examples/token-classification/run_ner.py

@@ -14,16 +14,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Fine-tuning the library models for named entity recognition on CoNLL-2003. """


 import logging
 import os
 import sys
 from dataclasses import dataclass, field
+from importlib import import_module
 from typing import Dict, List, Optional, Tuple

 import numpy as np
-from seqeval.metrics import f1_score, precision_score, recall_score
+from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
 from torch import nn

 from transformers import (

@@ -36,7 +35,7 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from utils_ner import NerDataset, Split, get_labels
+from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask


 logger = logging.getLogger(__name__)

@@ -54,6 +53,9 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
+    task_type: Optional[str] = field(
+        default="NER", metadata={"help": "Task type to fine tune in training (e.g. NER, POS, etc)"}
+    )
     tokenizer_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )

@@ -113,6 +115,16 @@ def main():
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

+    module = import_module("tasks")
+    try:
+        token_classification_task_clazz = getattr(module, model_args.task_type)
+        token_classification_task: TokenClassificationTask = token_classification_task_clazz()
+    except AttributeError:
+        raise ValueError(
+            f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
+            f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
+        )
+
     # Setup logging
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

@@ -133,7 +145,7 @@ def main():
     set_seed(training_args.seed)

     # Prepare CONLL-2003 task
-    labels = get_labels(data_args.labels)
+    labels = token_classification_task.get_labels(data_args.labels)
     label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
     num_labels = len(labels)

@@ -164,7 +176,8 @@ def main():

     # Get datasets
     train_dataset = (
-        NerDataset(
+        TokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,

@@ -177,7 +190,8 @@ def main():
         else None
     )
     eval_dataset = (
-        NerDataset(
+        TokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,

@@ -209,6 +223,7 @@ def main():
     def compute_metrics(p: EvalPrediction) -> Dict:
         preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
         return {
+            "accuracy_score": accuracy_score(out_label_list, preds_list),
             "precision": precision_score(out_label_list, preds_list),
             "recall": recall_score(out_label_list, preds_list),
             "f1": f1_score(out_label_list, preds_list),

@@ -253,7 +268,8 @@ def main():

     # Predict
     if training_args.do_predict:
-        test_dataset = NerDataset(
+        test_dataset = TokenClassificationDataset(
+            token_classification_task=token_classification_task,
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
             labels=labels,

@@ -278,19 +294,7 @@ def main():
         if trainer.is_world_master():
             with open(output_test_predictions_file, "w") as writer:
                 with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
-                    example_id = 0
-                    for line in f:
-                        if line.startswith("-DOCSTART-") or line == "" or line == "\n":
-                            writer.write(line)
-                            if not preds_list[example_id]:
-                                example_id += 1
-                        elif preds_list[example_id]:
-                            output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
-                            writer.write(output_line)
-                        else:
-                            logger.warning(
-                                "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
-                            )
+                    token_classification_task.write_predictions_to_file(writer, f, preds_list)

     return results
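The hunk at @@ -113,6 +115,16 @@ above is the core of the change: the --task_type string is resolved at runtime to a class defined in tasks.py, so adding a new token-classification task only requires dropping a TokenClassificationTask subclass into that module. A minimal sketch of the same lookup, assuming it is run from examples/token-classification/:

# Sketch of the dynamic task lookup used by run_ner.py and run_pl_ner.py.
from importlib import import_module

from utils_ner import TokenClassificationTask

module = import_module("tasks")        # loads examples/token-classification/tasks.py
task_class = getattr(module, "Chunk")  # the value passed via --task_type (NER, POS, Chunk, ...)
task: TokenClassificationTask = task_class()
print(task.get_labels(None))           # with no labels file, falls back to the task's default label list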
examples/token-classification/run_pl_ner.py

@@ -2,15 +2,17 @@ import argparse
 import glob
 import logging
 import os
+from argparse import Namespace
+from importlib import import_module

 import numpy as np
 import torch
-from seqeval.metrics import f1_score, precision_score, recall_score
+from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
 from torch.nn import CrossEntropyLoss
 from torch.utils.data import DataLoader, TensorDataset

 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
+from utils_ner import TokenClassificationTask


 logger = logging.getLogger(__name__)

@@ -24,10 +26,20 @@ class NERTransformer(BaseTransformer):
     mode = "token-classification"

     def __init__(self, hparams):
-        self.labels = get_labels(hparams.labels)
-        num_labels = len(self.labels)
+        if type(hparams) == dict:
+            hparams = Namespace(**hparams)
+        module = import_module("tasks")
+        try:
+            token_classification_task_clazz = getattr(module, hparams.task_type)
+            self.token_classification_task: TokenClassificationTask = token_classification_task_clazz()
+        except AttributeError:
+            raise ValueError(
+                f"Task {hparams.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
+                f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
+            )
+        self.labels = self.token_classification_task.get_labels(hparams.labels)
         self.pad_token_label_id = CrossEntropyLoss().ignore_index
-        super().__init__(hparams, num_labels, self.mode)
+        super().__init__(hparams, len(self.labels), self.mode)

     def forward(self, **inputs):
         return self.model(**inputs)

@@ -42,8 +54,8 @@ class NERTransformer(BaseTransformer):

         outputs = self(**inputs)
         loss = outputs[0]
-        tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
-        return {"loss": loss, "log": tensorboard_logs}
+        # tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
+        return {"loss": loss}

     def prepare_data(self):
         "Called to initialize data. Use the call to construct features"

@@ -55,8 +67,8 @@ class NERTransformer(BaseTransformer):
                 features = torch.load(cached_features_file)
             else:
                 logger.info("Creating features from dataset file at %s", args.data_dir)
-                examples = read_examples_from_file(args.data_dir, mode)
-                features = convert_examples_to_features(
+                examples = self.token_classification_task.read_examples_from_file(args.data_dir, mode)
+                features = self.token_classification_task.convert_examples_to_features(
                     examples,
                     self.labels,
                     args.max_seq_length,

@@ -74,7 +86,7 @@ class NERTransformer(BaseTransformer):
                 logger.info("Saving features into cached file %s", cached_features_file)
                 torch.save(features, cached_features_file)

-    def load_dataset(self, mode, batch_size):
+    def get_dataloader(self, mode: int, batch_size: int) -> DataLoader:
         "Load datasets. Called after prepare data."
         cached_features_file = self._feature_file(mode)
         logger.info("Loading features from cached file %s", cached_features_file)

@@ -124,6 +136,7 @@ class NERTransformer(BaseTransformer):

         results = {
             "val_loss": val_loss_mean,
+            "accuracy_score": accuracy_score(out_label_list, preds_list),
             "precision": precision_score(out_label_list, preds_list),
             "recall": recall_score(out_label_list, preds_list),
             "f1": f1_score(out_label_list, preds_list),

@@ -154,6 +167,9 @@ class NERTransformer(BaseTransformer):
     def add_model_specific_args(parser, root_dir):
         # Add NER specific options
         BaseTransformer.add_model_specific_args(parser, root_dir)
+        parser.add_argument(
+            "--task_type", default="NER", type=str, help="Task type to fine tune in training (e.g. NER, POS, etc)"
+        )
         parser.add_argument(
             "--max_seq_length",
             default=128,
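One detail worth noting in the __init__ hunk: PyTorch Lightning can hand hyperparameters back as a plain dict rather than a Namespace (for instance when a checkpoint is restored), which is presumably why the new guard normalizes them before attributes such as hparams.task_type are read. A minimal sketch with assumed values:

from argparse import Namespace

hparams = {"task_type": "POS", "labels": "", "max_seq_length": 128}  # assumed example values
if isinstance(hparams, dict):  # the diff itself uses `type(hparams) == dict`
    hparams = Namespace(**hparams)
print(hparams.task_type)  # -> "POS"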
examples/token-classification/run_pos.sh (new Executable file, +37)

@@ -0,0 +1,37 @@
+if ! [ -f ./dev.txt ]; then
+echo "Download dev dataset...."
+curl -L -o ./dev.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-dev.conllu'
+fi
+
+if ! [ -f ./test.txt ]; then
+echo "Download test dataset...."
+curl -L -o ./test.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-test.conllu'
+fi
+
+if ! [ -f ./train.txt ]; then
+echo "Download train dataset...."
+curl -L -o ./train.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-train.conllu'
+fi
+
+export MAX_LENGTH=200
+export BERT_MODEL=bert-base-uncased
+export OUTPUT_DIR=postagger-model
+export BATCH_SIZE=32
+export NUM_EPOCHS=3
+export SAVE_STEPS=750
+export SEED=1
+
+python3 run_ner.py \
+--task_type POS \
+--data_dir . \
+--model_name_or_path $BERT_MODEL \
+--output_dir $OUTPUT_DIR \
+--max_seq_length $MAX_LENGTH \
+--num_train_epochs $NUM_EPOCHS \
+--per_gpu_train_batch_size $BATCH_SIZE \
+--save_steps $SAVE_STEPS \
+--seed $SEED \
+--do_train \
+--do_eval \
+--do_predict
examples/token-classification/run_pos_pl.sh (new Executable file, +39)

@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+if ! [ -f ./dev.txt ]; then
+echo "Download dev dataset...."
+curl -L -o ./dev.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-dev.conllu'
+fi
+
+if ! [ -f ./test.txt ]; then
+echo "Download test dataset...."
+curl -L -o ./test.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-test.conllu'
+fi
+
+if ! [ -f ./train.txt ]; then
+echo "Download train dataset...."
+curl -L -o ./train.txt 'https://github.com/UniversalDependencies/UD_English-EWT/raw/master/en_ewt-ud-train.conllu'
+fi
+
+export MAX_LENGTH=200
+export BERT_MODEL=bert-base-uncased
+export OUTPUT_DIR=postagger-model
+export BATCH_SIZE=32
+export NUM_EPOCHS=3
+export SAVE_STEPS=750
+export SEED=1
+
+
+# Add parent directory to python path to access lightning_base.py
+export PYTHONPATH="../":"${PYTHONPATH}"
+
+python3 run_pl_ner.py --data_dir ./ \
+--task_type POS \
+--model_name_or_path $BERT_MODEL \
+--output_dir $OUTPUT_DIR \
+--max_seq_length $MAX_LENGTH \
+--num_train_epochs $NUM_EPOCHS \
+--train_batch_size $BATCH_SIZE \
+--seed $SEED \
+--gpus 1 \
+--do_train \
+--do_predict
examples/token-classification/tasks.py (new file, +163)

@@ -0,0 +1,163 @@
+import logging
+import os
+from typing import List, TextIO, Union
+
+from conllu import parse_incr
+
+from utils_ner import InputExample, Split, TokenClassificationTask
+
+
+logger = logging.getLogger(__name__)
+
+
+class NER(TokenClassificationTask):
+    def __init__(self, label_idx=-1):
+        # in NER datasets, the last column is usually reserved for NER label
+        self.label_idx = label_idx
+
+    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
+        if isinstance(mode, Split):
+            mode = mode.value
+        file_path = os.path.join(data_dir, f"{mode}.txt")
+        guid_index = 1
+        examples = []
+        with open(file_path, encoding="utf-8") as f:
+            words = []
+            labels = []
+            for line in f:
+                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
+                    if words:
+                        examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
+                        guid_index += 1
+                        words = []
+                        labels = []
+                else:
+                    splits = line.split(" ")
+                    words.append(splits[0])
+                    if len(splits) > 1:
+                        labels.append(splits[self.label_idx].replace("\n", ""))
+                    else:
+                        # Examples could have no label for mode = "test"
+                        labels.append("O")
+            if words:
+                examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
+        return examples
+
+    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
+        example_id = 0
+        for line in test_input_reader:
+            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
+                writer.write(line)
+                if not preds_list[example_id]:
+                    example_id += 1
+            elif preds_list[example_id]:
+                output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
+                writer.write(output_line)
+            else:
+                logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
+
+    def get_labels(self, path: str) -> List[str]:
+        if path:
+            with open(path, "r") as f:
+                labels = f.read().splitlines()
+            if "O" not in labels:
+                labels = ["O"] + labels
+            return labels
+        else:
+            return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
+
+
+class Chunk(NER):
+    def __init__(self):
+        # in CONLL2003 dataset chunk column is second-to-last
+        super().__init__(label_idx=-2)
+
+    def get_labels(self, path: str) -> List[str]:
+        if path:
+            with open(path, "r") as f:
+                labels = f.read().splitlines()
+            if "O" not in labels:
+                labels = ["O"] + labels
+            return labels
+        else:
+            return [
+                "O",
+                "B-ADVP",
+                "B-INTJ",
+                "B-LST",
+                "B-PRT",
+                "B-NP",
+                "B-SBAR",
+                "B-VP",
+                "B-ADJP",
+                "B-CONJP",
+                "B-PP",
+                "I-ADVP",
+                "I-INTJ",
+                "I-LST",
+                "I-PRT",
+                "I-NP",
+                "I-SBAR",
+                "I-VP",
+                "I-ADJP",
+                "I-CONJP",
+                "I-PP",
+            ]
+
+
+class POS(TokenClassificationTask):
+    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
+        if isinstance(mode, Split):
+            mode = mode.value
+        file_path = os.path.join(data_dir, f"{mode}.txt")
+        guid_index = 1
+        examples = []
+
+        with open(file_path, encoding="utf-8") as f:
+            for sentence in parse_incr(f):
+                words = []
+                labels = []
+                for token in sentence:
+                    words.append(token["form"])
+                    labels.append(token["upos"])
+                assert len(words) == len(labels)
+                if words:
+                    examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
+                    guid_index += 1
+        return examples
+
+    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: List):
+        example_id = 0
+        for sentence in parse_incr(test_input_reader):
+            s_p = preds_list[example_id]
+            out = ""
+            for token in sentence:
+                out += f'{token["form"]} ({token["upos"]}|{s_p.pop(0)}) '
+            out += "\n"
+            writer.write(out)
+            example_id += 1
+
+    def get_labels(self, path: str) -> List[str]:
+        if path:
+            with open(path, "r") as f:
+                return f.read().splitlines()
+        else:
+            return [
+                "ADJ",
+                "ADP",
+                "ADV",
+                "AUX",
+                "CCONJ",
+                "DET",
+                "INTJ",
+                "NOUN",
+                "NUM",
+                "PART",
+                "PRON",
+                "PROPN",
+                "PUNCT",
+                "SCONJ",
+                "SYM",
+                "VERB",
+                "X",
+            ]
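The POS task reads Universal Dependencies CoNLL-U files: run_pos.sh and run_pos_pl.sh save the downloaded en_ewt-ud-*.conllu files under the train.txt/dev.txt/test.txt names that read_examples_from_file expects, and conllu.parse_incr then yields one token list per sentence, each token exposing "form" (the word) and "upos" (the universal POS tag). A minimal sketch with a made-up three-token sentence:

from io import StringIO

from conllu import parse_incr

# A tiny, hand-written CoNLL-U fragment (10 tab-separated columns per token line).
sample = """# text = I like it
1\tI\tI\tPRON\tPRP\t_\t2\tnsubj\t_\t_
2\tlike\tlike\tVERB\tVBP\t_\t0\troot\t_\t_
3\tit\tit\tPRON\tPRP\t_\t2\tobj\t_\t_

"""

for sentence in parse_incr(StringIO(sample)):
    words = [token["form"] for token in sentence]
    labels = [token["upos"] for token in sentence]
    print(words, labels)  # ['I', 'like', 'it'] ['PRON', 'VERB', 'PRON']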
examples/token-classification/utils_ner.py

@@ -66,12 +66,148 @@ class Split(Enum):
     test = "test"


+class TokenClassificationTask:
+    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> List[InputExample]:
+        raise NotImplementedError
+
+    def get_labels(self, path: str) -> List[str]:
+        raise NotImplementedError
+
+    def convert_examples_to_features(
+        self,
+        examples: List[InputExample],
+        label_list: List[str],
+        max_seq_length: int,
+        tokenizer: PreTrainedTokenizer,
+        cls_token_at_end=False,
+        cls_token="[CLS]",
+        cls_token_segment_id=1,
+        sep_token="[SEP]",
+        sep_token_extra=False,
+        pad_on_left=False,
+        pad_token=0,
+        pad_token_segment_id=0,
+        pad_token_label_id=-100,
+        sequence_a_segment_id=0,
+        mask_padding_with_zero=True,
+    ) -> List[InputFeatures]:
+        """ Loads a data file into a list of `InputFeatures`
+        `cls_token_at_end` define the location of the CLS token:
+            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
+            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
+        `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
+        """
+        # TODO clean up all this to leverage built-in features of tokenizers
+
+        label_map = {label: i for i, label in enumerate(label_list)}
+
+        features = []
+        for (ex_index, example) in enumerate(examples):
+            if ex_index % 10_000 == 0:
+                logger.info("Writing example %d of %d", ex_index, len(examples))
+
+            tokens = []
+            label_ids = []
+            for word, label in zip(example.words, example.labels):
+                word_tokens = tokenizer.tokenize(word)
+
+                # bert-base-multilingual-cased sometimes output "nothing ([]) when calling tokenize with just a space.
+                if len(word_tokens) > 0:
+                    tokens.extend(word_tokens)
+                    # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+                    label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))
+
+            # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
+            special_tokens_count = tokenizer.num_special_tokens_to_add()
+            if len(tokens) > max_seq_length - special_tokens_count:
+                tokens = tokens[: (max_seq_length - special_tokens_count)]
+                label_ids = label_ids[: (max_seq_length - special_tokens_count)]
+
+            # The convention in BERT is:
+            # (a) For sequence pairs:
+            #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+            #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
+            # (b) For single sequences:
+            #  tokens:   [CLS] the dog is hairy . [SEP]
+            #  type_ids:   0   0   0   0  0     0   0
+            #
+            # Where "type_ids" are used to indicate whether this is the first
+            # sequence or the second sequence. The embedding vectors for `type=0` and
+            # `type=1` were learned during pre-training and are added to the wordpiece
+            # embedding vector (and position vector). This is not *strictly* necessary
+            # since the [SEP] token unambiguously separates the sequences, but it makes
+            # it easier for the model to learn the concept of sequences.
+            #
+            # For classification tasks, the first vector (corresponding to [CLS]) is
+            # used as as the "sentence vector". Note that this only makes sense because
+            # the entire model is fine-tuned.
+            tokens += [sep_token]
+            label_ids += [pad_token_label_id]
+            if sep_token_extra:
+                # roberta uses an extra separator b/w pairs of sentences
+                tokens += [sep_token]
+                label_ids += [pad_token_label_id]
+            segment_ids = [sequence_a_segment_id] * len(tokens)
+
+            if cls_token_at_end:
+                tokens += [cls_token]
+                label_ids += [pad_token_label_id]
+                segment_ids += [cls_token_segment_id]
+            else:
+                tokens = [cls_token] + tokens
+                label_ids = [pad_token_label_id] + label_ids
+                segment_ids = [cls_token_segment_id] + segment_ids
+
+            input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+            # The mask has 1 for real tokens and 0 for padding tokens. Only real
+            # tokens are attended to.
+            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+
+            # Zero-pad up to the sequence length.
+            padding_length = max_seq_length - len(input_ids)
+            if pad_on_left:
+                input_ids = ([pad_token] * padding_length) + input_ids
+                input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
+                segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
+                label_ids = ([pad_token_label_id] * padding_length) + label_ids
+            else:
+                input_ids += [pad_token] * padding_length
+                input_mask += [0 if mask_padding_with_zero else 1] * padding_length
+                segment_ids += [pad_token_segment_id] * padding_length
+                label_ids += [pad_token_label_id] * padding_length
+
+            assert len(input_ids) == max_seq_length
+            assert len(input_mask) == max_seq_length
+            assert len(segment_ids) == max_seq_length
+            assert len(label_ids) == max_seq_length
+
+            if ex_index < 5:
+                logger.info("*** Example ***")
+                logger.info("guid: %s", example.guid)
+                logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
+                logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+                logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+                logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+                logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
+
+            if "token_type_ids" not in tokenizer.model_input_names:
+                segment_ids = None
+
+            features.append(
+                InputFeatures(
+                    input_ids=input_ids, attention_mask=input_mask, token_type_ids=segment_ids, label_ids=label_ids
+                )
+            )
+        return features
+
+
 if is_torch_available():
     import torch
     from torch import nn
     from torch.utils.data.dataset import Dataset

-    class NerDataset(Dataset):
+    class TokenClassificationDataset(Dataset):
         """
         This will be superseded by a framework-agnostic approach
         soon.

@@ -84,6 +220,7 @@ if is_torch_available():

         def __init__(
             self,
+            token_classification_task: TokenClassificationTask,
             data_dir: str,
             tokenizer: PreTrainedTokenizer,
             labels: List[str],

@@ -107,9 +244,9 @@ if is_torch_available():
                 self.features = torch.load(cached_features_file)
             else:
                 logger.info(f"Creating features from dataset file at {data_dir}")
-                examples = read_examples_from_file(data_dir, mode)
+                examples = token_classification_task.read_examples_from_file(data_dir, mode)
                 # TODO clean up all this to leverage built-in features of tokenizers
-                self.features = convert_examples_to_features(
+                self.features = token_classification_task.convert_examples_to_features(
                     examples,
                     labels,
                     max_seq_length,

@@ -152,6 +289,7 @@ if is_tf_available():

         def __init__(
             self,
+            token_classification_task: TokenClassificationTask,
             data_dir: str,
             tokenizer: PreTrainedTokenizer,
             labels: List[str],

@@ -160,9 +298,9 @@ if is_tf_available():
             overwrite_cache=False,
             mode: Split = Split.train,
         ):
-            examples = read_examples_from_file(data_dir, mode)
+            examples = token_classification_task.read_examples_from_file(data_dir, mode)
             # TODO clean up all this to leverage built-in features of tokenizers
-            self.features = convert_examples_to_features(
+            self.features = token_classification_task.convert_examples_to_features(
                 examples,
                 labels,
                 max_seq_length,

@@ -230,171 +368,3 @@ if is_tf_available():

         def __getitem__(self, i) -> InputFeatures:
             return self.features[i]
-
-
-def read_examples_from_file(data_dir, mode: Union[Split, str]) -> List[InputExample]:
-    ...
-
-
-def convert_examples_to_features(
-    examples: List[InputExample],
-    label_list: List[str],
-    max_seq_length: int,
-    tokenizer: PreTrainedTokenizer,
-    ...
-) -> List[InputFeatures]:
-    ...
-
-
-def get_labels(path: str) -> List[str]:
-    ...

(The deleted bodies are not repeated here: apart from dropping the `self` parameter and reading the NER label from the fixed last column, `splits[-1]`, they match the `convert_examples_to_features` method added to `TokenClassificationTask` above and the `read_examples_from_file` / `get_labels` methods added to `NER` in tasks.py.)
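The convert_examples_to_features method moved onto TokenClassificationTask keeps the examples' label-alignment rule: after WordPiece tokenization, only the first sub-token of each word carries the real label id, and every remaining sub-token gets pad_token_label_id (-100), which CrossEntropyLoss ignores. A minimal sketch of that rule with an assumed tokenization:

# Sketch of the sub-token label alignment performed in convert_examples_to_features.
label_map = {"O": 0, "B-PER": 1, "I-PER": 2}
pad_token_label_id = -100  # nn.CrossEntropyLoss().ignore_index

words = ["Johanson", "lives", "here"]
labels = ["B-PER", "O", "O"]
# Assumed WordPiece output, for illustration only.
tokenized = {"Johanson": ["johan", "##son"], "lives": ["lives"], "here": ["here"]}

tokens, label_ids = [], []
for word, label in zip(words, labels):
    word_tokens = tokenized[word]
    tokens.extend(word_tokens)
    # Real label id on the first sub-token, ignore-index on the rest.
    label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))

print(tokens)     # ['johan', '##son', 'lives', 'here']
print(label_ids)  # [1, -100, 0, 0]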