Merge remote-tracking branch 'refs/remotes/huggingface/master'
Commit 76f0d99f02
@@ -58,7 +58,7 @@ Choose the right framework for every part of a model's lifetime
 | [Quick tour: Fine-tuning/usage scripts](#quick-tour-of-the-fine-tuningusage-scripts) | Using provided scripts: GLUE, SQuAD and Text generation |
 | [Migrating from pytorch-transformers to transformers](#Migrating-from-pytorch-transformers-to-transformers) | Migrating your code from pytorch-transformers to transformers |
 | [Migrating from pytorch-pretrained-bert to pytorch-transformers](#Migrating-from-pytorch-pretrained-bert-to-transformers) | Migrating your code from pytorch-pretrained-bert to transformers |
-| [Documentation][(v2.2.0/v2.2.1)](https://huggingface.co/transformers/v2.2.0) [(v2.1.1)](https://huggingface.co/transformers/v2.1.1) [(v2.0.0)](https://huggingface.co/transformers/v2.0.0) [(v1.2.0)](https://huggingface.co/transformers/v1.2.0) [(v1.1.0)](https://huggingface.co/transformers/v1.1.0) [(v1.0.0)](https://huggingface.co/transformers/v1.0.0) [(master)](https://huggingface.co/transformers) | Full API documentation and more |
+| [Documentation][(v2.2.0/v2.2.1/v2.2.2)](https://huggingface.co/transformers/v2.2.0) [(v2.1.1)](https://huggingface.co/transformers/v2.1.1) [(v2.0.0)](https://huggingface.co/transformers/v2.0.0) [(v1.2.0)](https://huggingface.co/transformers/v1.2.0) [(v1.1.0)](https://huggingface.co/transformers/v1.1.0) [(v1.0.0)](https://huggingface.co/transformers/v1.0.0) [(master)](https://huggingface.co/transformers) | Full API documentation and more |
 
 ## Installation
@@ -26,7 +26,7 @@ author = u'huggingface'
 
 # The short X.Y version
 version = u''
 # The full version, including alpha/beta/rc tags
-release = u'2.2.1'
+release = u'2.2.2'
 
 # -- General configuration ---------------------------------------------------
@@ -580,10 +580,16 @@ def main():
 
     # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
-        checkpoints = [args.output_dir]
-        if args.eval_all_checkpoints:
-            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
-            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
+        if args.do_train:
+            logger.info("Loading checkpoints saved during training for evaluation")
+            checkpoints = [args.output_dir]
+            if args.eval_all_checkpoints:
+                checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+                logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
+        else:
+            logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
+            checkpoints = [args.model_name_or_path]
 
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
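For context on the globbing above: every sub-directory containing a saved weights file is treated as a checkpoint. A standalone sketch of the same discovery logic (assuming the library's usual `WEIGHTS_NAME` of `pytorch_model.bin`; the directory name is hypothetical):

```python
import glob
import os

WEIGHTS_NAME = "pytorch_model.bin"  # assumption: the library's default weights filename

output_dir = "output"  # hypothetical --output_dir
# Every folder holding a weights file counts as a checkpoint,
# e.g. output/checkpoint-500/pytorch_model.bin -> output/checkpoint-500
checkpoints = list(
    os.path.dirname(c)
    for c in sorted(glob.glob(output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
print(checkpoints)
```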
@@ -29,7 +29,7 @@ And move all the stories to the same folder. We will refer as `$DATA_PATH` the p
 python run_summarization.py \
     --documents_dir $DATA_PATH \
     --summaries_output_dir $SUMMARIES_PATH \ # optional
-    --to_cpu false \
+    --no_cuda false \
     --batch_size 4 \
     --min_length 50 \
     --max_length 200 \
@@ -39,7 +39,7 @@ python run_summarization.py \
     --compute_rouge true
 ```
 
-The script executes on GPU if one is available and if `to_cpu` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes about 30 hours on a single Tesla V100 GPU with a batch size of 10 (300,000 texts to summarize).
+The script executes on GPU if one is available and if `no_cuda` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes about 30 hours on a single Tesla V100 GPU with a batch size of 10 (300,000 texts to summarize).
 
 ## Summarize any text
@@ -49,7 +49,7 @@ Put the documents that you would like to summarize in a folder (the path to whic
 python run_summarization.py \
     --documents_dir $DATA_PATH \
     --summaries_output_dir $SUMMARIES_PATH \ # optional
-    --to_cpu false \
+    --no_cuda false \
     --batch_size 4 \
     --min_length 50 \
     --max_length 200 \
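The `--no_cuda` flag renamed above follows the usual device-selection pattern in these example scripts; a minimal sketch of what it controls (the argument wiring here is illustrative, not the script's actual parser):

```python
import torch

no_cuda = False  # stands in for args.no_cuda parsed from the command line

# Run on GPU only when one is available and the user did not opt out.
device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
print(device)
```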
setup.py
@@ -44,7 +44,7 @@ extras['all'] = [package for package in extras.values()]
 
 setup(
     name="transformers",
-    version="2.2.1",
+    version="2.2.2",
     author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
     author_email="thomas@huggingface.co",
     description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
@@ -1,4 +1,4 @@
-__version__ = "2.2.1"
+__version__ = "2.2.2"
 
 # Work around to update TensorFlow's absl.logging threshold which alters the
 # default Python logging output behavior when present.
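A quick way to confirm the bumped release once installed (sketch; it prints whatever version is actually present):

```python
import transformers

print(transformers.__version__)  # "2.2.2" for this release
```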
@@ -19,8 +19,8 @@ class UserCommands(BaseTransformersCLICommand):
         list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
         # upload
         upload_parser = parser.add_parser('upload')
-        upload_parser.add_argument('file', type=str, help='Local filepath of the file to upload.')
-        upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override object filename on S3.')
+        upload_parser.add_argument('path', type=str, help='Local path of the folder or individual file to upload.')
+        upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override individual object filename on S3.')
         upload_parser.set_defaults(func=lambda args: UploadCommand(args))
@@ -138,28 +138,57 @@ class ListObjsCommand(BaseUserCommand):
 
 
 class UploadCommand(BaseUserCommand):
+    def walk_dir(self, rel_path):
+        """
+        Recursively list all files in a folder.
+        """
+        entries: List[os.DirEntry] = list(os.scandir(rel_path))
+        files = [
+            (
+                os.path.join(os.getcwd(), f.path),  # filepath
+                f.path  # filename
+            )
+            for f in entries if f.is_file()
+        ]
+        for f in entries:
+            if f.is_dir():
+                files += self.walk_dir(f.path)
+        return files
+
     def run(self):
         token = HfFolder.get_token()
         if token is None:
             print("Not logged in")
             exit(1)
-        filepath = os.path.join(os.getcwd(), self.args.file)
-        filename = self.args.filename if self.args.filename is not None else os.path.basename(filepath)
-        print(
-            "About to upload file {} to S3 under filename {}".format(
-                ANSI.bold(filepath), ANSI.bold(filename)
+        local_path = os.path.abspath(self.args.path)
+        if os.path.isdir(local_path):
+            if self.args.filename is not None:
+                raise ValueError("Cannot specify a filename override when uploading a folder.")
+            rel_path = os.path.basename(local_path)
+            files = self.walk_dir(rel_path)
+        elif os.path.isfile(local_path):
+            filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
+            files = [(local_path, filename)]
+        else:
+            raise ValueError("Not a valid file or directory: {}".format(local_path))
+
+        for filepath, filename in files:
+            print(
+                "About to upload file {} to S3 under filename {}".format(
+                    ANSI.bold(filepath), ANSI.bold(filename)
+                )
             )
-        )
 
         choice = input("Proceed? [Y/n] ").lower()
         if not(choice == "" or choice == "y" or choice == "yes"):
             print("Abort")
             exit()
         print(
-            ANSI.bold("Uploading... This might take a while if file is large")
+            ANSI.bold("Uploading... This might take a while if files are large")
         )
-        access_url = self._api.presign_and_upload(
-            token=token, filename=filename, filepath=filepath
-        )
-        print("Your file now lives at:")
-        print(access_url)
+        for filepath, filename in files:
+            access_url = self._api.presign_and_upload(
+                token=token, filename=filename, filepath=filepath
+            )
+            print("Your file now lives at:")
+            print(access_url)
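For intuition, `walk_dir` returns `(absolute filepath, S3 object name)` pairs, where the object name is the path relative to the current working directory; this is why `--filename` is rejected for folders, since each file keeps its relative path as its object name. A self-contained sketch of the same traversal (the folder name is hypothetical):

```python
import os
from typing import List, Tuple

def walk_dir(rel_path: str) -> List[Tuple[str, str]]:
    # The absolute path is used for reading the file; the relative path
    # (e.g. "my-model/config.json") becomes the object filename on S3.
    entries = list(os.scandir(rel_path))
    files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()]
    for f in entries:
        if f.is_dir():
            files += walk_dir(f.path)
    return files

print(walk_dir("my-model"))  # hypothetical folder inside the current directory
```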
@@ -133,7 +133,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
     if is_tf_available() and is_tf_dataset:
         def gen():
             for ex in features:
-                yield ({'input_ids': ex.input_ids,
+                yield ({'input_ids': ex.input_ids,
                         'attention_mask': ex.attention_mask,
                         'token_type_ids': ex.token_type_ids},
                        ex.label)
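The `gen()` above feeds `tf.data.Dataset.from_generator`; a self-contained sketch of the pattern with dummy features standing in for the real GLUE `InputFeatures`:

```python
import tensorflow as tf

# Dummy stand-ins for the features produced by glue_convert_examples_to_features
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "token_type_ids": [0, 0, 0], "label": 1},
]

def gen():
    for ex in features:
        yield (
            {
                "input_ids": ex["input_ids"],
                "attention_mask": ex["attention_mask"],
                "token_type_ids": ex["token_type_ids"],
            },
            ex["label"],
        )

# Output types and shapes mirror the per-example structure yielded above.
dataset = tf.data.Dataset.from_generator(
    gen,
    ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
    (
        {
            "input_ids": tf.TensorShape([None]),
            "attention_mask": tf.TensorShape([None]),
            "token_type_ids": tf.TensorShape([None]),
        },
        tf.TensorShape([]),
    ),
)
print(next(iter(dataset)))
```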
@@ -18,19 +18,20 @@ if is_tf_available():
 
 logger = logging.getLogger(__name__)
 
-def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
-                         orig_answer_text):
+
+def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
     """Returns tokenized answer spans that better match the annotated answer."""
     tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
 
     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
-            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
+            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
             if text_span == tok_answer_text:
                 return (new_start, new_end)
 
     return (input_start, input_end)
 
 
 def _check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     best_score = None
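For intuition about `_improve_answer_span`: the annotated span sometimes covers more tokens than the gold answer after tokenization, and the function shrinks it to the tightest sub-span that detokenizes to the answer. A toy run (a trivial whitespace tokenizer stands in for the real one; the function body is the one above):

```python
class WhitespaceTokenizer:
    def tokenize(self, text):
        return text.split()

def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)
    return (input_start, input_end)

doc_tokens = ["John", "Smith", "(", "1895", "-", "1943", ")"]
# The annotated span covers "( 1895" (indices 2..3) but the gold answer is just "1895".
print(_improve_answer_span(doc_tokens, 2, 3, WhitespaceTokenizer(), "1895"))  # (3, 3)
```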
@@ -50,10 +51,11 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
 
     return cur_span_index == best_span_index
 
+
 def _new_check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     # if len(doc_spans) == 1:
-    # return True
+    #     return True
     best_score = None
     best_span_index = None
     for (span_index, doc_span) in enumerate(doc_spans):
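Background for the two `*_check_is_max_context` helpers, whose bodies are only partially visible in these hunks: with a sliding window, a token can fall in several overlapping doc spans, and predictions are taken from the span where it has the most surrounding context. A sketch of the standard scoring rule from the original BERT code, which these helpers follow (assuming each span carries `start` and `length`):

```python
def span_score(doc_span, position):
    # More context on both sides scores higher; the small length bonus
    # breaks ties in favor of longer spans.
    num_left = position - doc_span["start"]
    num_right = doc_span["start"] + doc_span["length"] - 1 - position
    return min(num_left, num_right) + 0.01 * doc_span["length"]

# The token at position 5 appears in two overlapping spans:
spans = [{"start": 0, "length": 8}, {"start": 4, "length": 8}]
print([span_score(s, 5) for s in spans])  # [2.08, 1.08] -> the first span is "max context"
```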
@@ -71,14 +73,16 @@ def _new_check_is_max_context(doc_spans, cur_span_index, position):
 
     return cur_span_index == best_span_index
 
+
 def _is_whitespace(c):
     if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
         return True
     return False
 
-def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
-                                       doc_stride, max_query_length, is_training,
-                                       return_dataset=False):
+
+def squad_convert_examples_to_features(
+    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False
+):
     """
     Converts a list of examples into a list of features that can be directly given as input to a model.
     It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
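A hedged usage sketch for the reformatted function (import paths follow the v2.x layout; the data directory and its files are assumptions):

```python
from transformers import BertTokenizer
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = SquadV2Processor()
examples = processor.get_dev_examples("./squad_data")  # expects dev-v2.0.json here

features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",  # "tf" for a tf.data.Dataset, False for features only
)
```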
@@ -112,24 +116,23 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     )
     """
 
-    # Defining helper methods 
+    # Defining helper methods
    unique_id = 1000000000
 
     features = []
-    for (example_index, example) in enumerate(tqdm(examples)):
+    for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")):
         if is_training and not example.is_impossible:
             # Get start and end position
             start_position = example.start_position
             end_position = example.end_position
 
             # If the answer cannot be found in the text, then skip this example.
-            actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
+            actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
             cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
             if actual_text.find(cleaned_answer_text) == -1:
                 logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
                 continue
 
         tok_to_orig_index = []
         orig_to_tok_index = []
         all_doc_tokens = []
@@ -140,7 +143,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)
 
-
         if is_training and not example.is_impossible:
             tok_start_position = orig_to_tok_index[example.start_position]
             if example.end_position < len(example.doc_tokens) - 1:
@@ -153,36 +155,41 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         )
 
         spans = []
 
-        truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
-        sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
-        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
+        truncated_query = tokenizer.encode(
+            example.question_text, add_special_tokens=False, max_length=max_query_length
+        )
+        sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
+        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
 
         span_doc_tokens = all_doc_tokens
         while len(spans) * doc_stride < len(all_doc_tokens):
 
             encoded_dict = tokenizer.encode_plus(
-                truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
-                span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
-                max_length=max_seq_length,
-                return_overflowing_tokens=True,
+                truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
+                span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
+                max_length=max_seq_length,
+                return_overflowing_tokens=True,
                 pad_to_max_length=True,
                 stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-                truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
+                truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
             )
 
-            paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
+            paragraph_len = min(
+                len(all_doc_tokens) - len(spans) * doc_stride,
+                max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
+            )
 
-            if tokenizer.pad_token_id in encoded_dict['input_ids']:
-                non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            if tokenizer.pad_token_id in encoded_dict["input_ids"]:
+                non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
             else:
-                non_padded_ids = encoded_dict['input_ids']
+                non_padded_ids = encoded_dict["input_ids"]
 
             tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
 
             token_to_orig_map = {}
             for i in range(paragraph_len):
-                index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
+                index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
                 token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
 
             encoded_dict["paragraph_len"] = paragraph_len
@@ -202,16 +209,20 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         for doc_span_index in range(len(spans)):
             for j in range(spans[doc_span_index]["paragraph_len"]):
                 is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-                index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                index = (
+                    j
+                    if tokenizer.padding_side == "left"
+                    else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                )
                 spans[doc_span_index]["token_is_max_context"][index] = is_max_context
 
         for span in spans:
             # Identify the position of the CLS token
-            cls_index = span['input_ids'].index(tokenizer.cls_token_id)
+            cls_index = span["input_ids"].index(tokenizer.cls_token_id)
 
             # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
             # Original TF implem also keeps the classification token (set to 0) (not sure why...)
-            p_mask = np.array(span['token_type_ids'])
+            p_mask = np.array(span["token_type_ids"])
 
             p_mask = np.minimum(p_mask, 1)
@@ -224,7 +235,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # Set the CLS index to '0'
             p_mask[cls_index] = 0
 
-
             span_is_impossible = example.is_impossible
             start_position = 0
             end_position = 0
@@ -247,55 +257,99 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                     doc_offset = 0
                 else:
                     doc_offset = len(truncated_query) + sequence_added_tokens
 
                 start_position = tok_start_position - doc_start + doc_offset
                 end_position = tok_end_position - doc_start + doc_offset
 
-            features.append(SquadFeatures(
-                span['input_ids'],
-                span['attention_mask'],
-                span['token_type_ids'],
-                cls_index,
-                p_mask.tolist(),
-
-                example_index=example_index,
-                unique_id=unique_id,
-                paragraph_len=span['paragraph_len'],
-                token_is_max_context=span["token_is_max_context"],
-                tokens=span["tokens"],
-                token_to_orig_map=span["token_to_orig_map"],
-
-                start_position=start_position,
-                end_position=end_position
-            ))
+            features.append(
+                SquadFeatures(
+                    span["input_ids"],
+                    span["attention_mask"],
+                    span["token_type_ids"],
+                    cls_index,
+                    p_mask.tolist(),
+                    example_index=example_index,
+                    unique_id=unique_id,
+                    paragraph_len=span["paragraph_len"],
+                    token_is_max_context=span["token_is_max_context"],
+                    tokens=span["tokens"],
+                    token_to_orig_map=span["token_to_orig_map"],
+                    start_position=start_position,
+                    end_position=end_position,
+                )
+            )
 
             unique_id += 1
 
-    if return_dataset == 'pt':
+    if return_dataset == "pt":
         if not is_torch_available():
             raise ImportError("Pytorch must be installed to return a pytorch dataset.")
 
         # Convert to Tensors and build dataset
         all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-        all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
-        all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
+        all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
         all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
         all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
 
         if not is_training:
             all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                    all_example_index, all_cls_index, all_p_mask)
+            dataset = TensorDataset(
+                all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
+            )
         else:
             all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
             all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
-            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                    all_start_positions, all_end_positions,
-                                    all_cls_index, all_p_mask)
+            dataset = TensorDataset(
+                all_input_ids,
+                all_attention_masks,
+                all_token_type_ids,
+                all_start_positions,
+                all_end_positions,
+                all_cls_index,
+                all_p_mask,
+            )
 
         return features, dataset
 
+    elif return_dataset == "tf":
+        if not is_tf_available():
+            raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")
+
+        def gen():
+            for ex in features:
+                yield (
+                    {
+                        "input_ids": ex.input_ids,
+                        "attention_mask": ex.attention_mask,
+                        "token_type_ids": ex.token_type_ids,
+                    }, {
+                        "start_position": ex.start_position,
+                        "end_position": ex.end_position,
+                        "cls_index": ex.cls_index,
+                        "p_mask": ex.p_mask,
+                    }
+                )
+
+        return tf.data.Dataset.from_generator(
+            gen,
+            (
+                {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
+                {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32},
+            ),
+            (
+                {
+                    "input_ids": tf.TensorShape([None]),
+                    "attention_mask": tf.TensorShape([None]),
+                    "token_type_ids": tf.TensorShape([None]),
+                },
+                {
+                    "start_position": tf.TensorShape([]),
+                    "end_position": tf.TensorShape([]),
+                    "cls_index": tf.TensorShape([]),
+                    "p_mask": tf.TensorShape([None]),
+                },
+            ),
+        )
 
     return features
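A sketch of consuming the returned PyTorch dataset (batch columns follow the `TensorDataset` packing order above; the evaluation-mode layout is shown, and the model call is indicative only):

```python
from torch.utils.data import DataLoader, SequentialSampler

# Eval-mode columns: input_ids, attention_masks, token_type_ids,
# example_index, cls_index, p_mask
eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=8)

for batch in eval_dataloader:
    input_ids, attention_mask, token_type_ids = batch[0], batch[1], batch[2]
    # outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
```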
@@ -305,31 +359,32 @@ class SquadProcessor(DataProcessor):
     Processor for the SQuAD data set.
     Overridden by SquadV1Processor and SquadV2Processor, used for version 1.1 and version 2.0 of SQuAD, respectively.
     """
+
     train_file = None
     dev_file = None
 
     def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
         if not evaluate:
-            answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8')
-            answer_start = tensor_dict['answers']['answer_start'][0].numpy()
+            answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
+            answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
             answers = []
         else:
-            answers = [{
-                "answer_start": start.numpy(),
-                "text": text.numpy().decode('utf-8')
-            } for start, text in zip(tensor_dict['answers']["answer_start"], tensor_dict['answers']["text"])]
+            answers = [
+                {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
+                for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
+            ]
 
             answer = None
             answer_start = None
 
         return SquadExample(
-            qas_id=tensor_dict['id'].numpy().decode("utf-8"),
-            question_text=tensor_dict['question'].numpy().decode('utf-8'),
-            context_text=tensor_dict['context'].numpy().decode('utf-8'),
+            qas_id=tensor_dict["id"].numpy().decode("utf-8"),
+            question_text=tensor_dict["question"].numpy().decode("utf-8"),
+            context_text=tensor_dict["context"].numpy().decode("utf-8"),
             answer_text=answer,
             start_position_character=answer_start,
-            title=tensor_dict['title'].numpy().decode('utf-8'),
-            answers=answers
+            title=tensor_dict["title"].numpy().decode("utf-8"),
+            answers=answers,
         )
 
     def get_examples_from_dataset(self, dataset, evaluate=False):
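`_get_example_from_tensor_dict` backs `get_examples_from_dataset`, which reads examples straight from a tensorflow_datasets split; a hedged sketch (it assumes the TFDS `squad` builder provides the id/question/context/answers fields used above):

```python
import tensorflow_datasets as tfds

data = tfds.load("squad")
processor = SquadV1Processor()  # concrete subclass shown later in this diff
examples = processor.get_examples_from_dataset(data["validation"], evaluate=True)
```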
@@ -359,7 +414,7 @@ class SquadProcessor(DataProcessor):
 
         examples = []
         for tensor_dict in tqdm(dataset):
-            examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
+            examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
 
         return examples
@@ -379,7 +434,9 @@ class SquadProcessor(DataProcessor):
         if self.train_file is None:
             raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
 
-        with open(os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding='utf-8') as reader:
+        with open(
+            os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
+        ) as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "train")
@@ -397,8 +454,10 @@ class SquadProcessor(DataProcessor):
 
         if self.dev_file is None:
             raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
-        with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader:
+
+        with open(
+            os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
+        ) as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "dev")
@@ -406,7 +465,7 @@ class SquadProcessor(DataProcessor):
         is_training = set_type == "train"
         examples = []
         for entry in tqdm(input_data):
-            title = entry['title']
+            title = entry["title"]
             for paragraph in entry["paragraphs"]:
                 context_text = paragraph["context"]
                 for qa in paragraph["qas"]:
@@ -415,7 +474,7 @@ class SquadProcessor(DataProcessor):
                     start_position_character = None
                     answer_text = None
                     answers = []
 
                     if "is_impossible" in qa:
                         is_impossible = qa["is_impossible"]
                     else:
@@ -424,8 +483,8 @@ class SquadProcessor(DataProcessor):
                     if not is_impossible:
                         if is_training:
                             answer = qa["answers"][0]
-                            answer_text = answer['text']
-                            start_position_character = answer['answer_start']
+                            answer_text = answer["text"]
+                            start_position_character = answer["answer_start"]
                         else:
                             answers = qa["answers"]
@@ -437,12 +496,13 @@ class SquadProcessor(DataProcessor):
                         start_position_character=start_position_character,
                         title=title,
                         is_impossible=is_impossible,
-                        answers=answers
+                        answers=answers,
                     )
 
                     examples.append(example)
         return examples
 
+
 class SquadV1Processor(SquadProcessor):
     train_file = "train-v1.1.json"
     dev_file = "dev-v1.1.json"
@@ -451,7 +511,7 @@ class SquadV1Processor(SquadProcessor):
 class SquadV2Processor(SquadProcessor):
     train_file = "train-v2.0.json"
     dev_file = "dev-v2.0.json"
-
+
 
 class SquadExample(object):
     """
@@ -468,21 +528,23 @@ class SquadExample(object):
         is_impossible: False by default, set to True if the example has no possible answer.
     """
 
-    def __init__(self,
-                 qas_id,
-                 question_text,
-                 context_text,
-                 answer_text,
-                 start_position_character,
-                 title,
-                 answers=[],
-                 is_impossible=False):
+    def __init__(
+        self,
+        qas_id,
+        question_text,
+        context_text,
+        answer_text,
+        start_position_character,
+        title,
+        answers=[],
+        is_impossible=False,
+    ):
         self.qas_id = qas_id
         self.question_text = question_text
         self.context_text = context_text
         self.answer_text = answer_text
         self.title = title
-        self.is_impossible = is_impossible
+        self.is_impossible = is_impossible
         self.answers = answers
 
         self.start_position, self.end_position = 0, 0
@@ -537,24 +599,23 @@ class SquadFeatures(object):
         end_position: end of the answer token index
     """
 
-    def __init__(self,
-                 input_ids,
-                 attention_mask,
-                 token_type_ids,
-                 cls_index,
-                 p_mask,
-
-                 example_index,
-                 unique_id,
-                 paragraph_len,
-                 token_is_max_context,
-                 tokens,
-                 token_to_orig_map,
-
-                 start_position,
-                 end_position
-    ):
-        self.input_ids = input_ids
+    def __init__(
+        self,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        cls_index,
+        p_mask,
+        example_index,
+        unique_id,
+        paragraph_len,
+        token_is_max_context,
+        tokens,
+        token_to_orig_map,
+        start_position,
+        end_position,
+    ):
+        self.input_ids = input_ids
         self.attention_mask = attention_mask
         self.token_type_ids = token_type_ids
         self.cls_index = cls_index
@@ -580,12 +641,13 @@ class SquadResult(object):
         start_logits: The logits corresponding to the start of the answer
         end_logits: The logits corresponding to the end of the answer
     """
+
     def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
         self.start_logits = start_logits
         self.end_logits = end_logits
         self.unique_id = unique_id
 
         if start_top_index:
             self.start_top_index = start_top_index
             self.end_top_index = end_top_index
-            self.cls_logits = cls_logits
+            self.cls_logits = cls_logits
@@ -139,5 +139,6 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         assert encoded_sentence == [101] + text + [102]
         assert encoded_pair == [101] + text + [102] + text_2 + [102]
 
+
 if __name__ == '__main__':
     unittest.main()
@@ -67,6 +67,5 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
-
 if __name__ == '__main__':
     unittest.main()
@@ -232,6 +232,15 @@ class CommonTestCases:
             self.assertNotEqual(len(tokens_2), 0)
             self.assertIsInstance(text_2, (str, unicode))
 
+        def test_encode_decode_with_spaces(self):
+            tokenizer = self.get_tokenizer()
+
+            new_toks = ['[ABC]', '[DEF]', 'GHI IHG']
+            tokenizer.add_tokens(new_toks)
+            input = "[ABC] [DEF] [ABC] GHI IHG [DEF]"
+            encoded = tokenizer.encode(input, add_special_tokens=False)
+            decoded = tokenizer.decode(encoded)
+            self.assertEqual(decoded, input)
+
         def test_pretrained_model_lists(self):
             weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
@@ -637,9 +637,11 @@ class PreTrainedTokenizer(object):
             text: The sequence to be encoded.
             **kwargs: passed to the child `self.tokenize()` method
         """
+        all_special_tokens = self.all_special_tokens
+
        def lowercase_text(t):
             # convert non-special tokens to lowercase
-            escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
+            escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
             pattern = r'(^' + r'|'.join(escaped_special_toks) + r')|' + \
                       r'(.+?)'
             return re.sub(
@@ -680,17 +682,17 @@ class PreTrainedTokenizer(object):
             tokenized_text = []
             for sub_text in text_list:
                 if sub_text not in self.added_tokens_encoder \
-                        and sub_text not in self.all_special_tokens:
+                        and sub_text not in all_special_tokens:
                     tokenized_text += split_on_token(tok, sub_text)
                 else:
                     tokenized_text += [sub_text]
             text_list = tokenized_text
 
            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
-                in self.added_tokens_encoder and token not in self.all_special_tokens \
+                in self.added_tokens_encoder and token not in all_special_tokens \
                 else [token] for token in tokenized_text)))
 
-        added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
+        added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
         tokenized_text = split_on_tokens(added_tokens, text)
         return tokenized_text
@@ -1178,12 +1180,12 @@ class PreTrainedTokenizer(object):
                 if current_sub_text:
                     sub_texts.append(self.convert_tokens_to_string(current_sub_text))
                     current_sub_text = []
-                sub_texts.append(" " + token)
+                sub_texts.append(token)
             else:
                 current_sub_text.append(token)
         if current_sub_text:
             sub_texts.append(self.convert_tokens_to_string(current_sub_text))
-        text = ''.join(sub_texts)
+        text = ' '.join(sub_texts)
 
         if clean_up_tokenization_spaces:
             clean_text = self.clean_up_tokenization(text)
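A toy illustration of why the two `decode` changes above go together: previously a space was prepended to each added token and the pieces were joined with `''`, which dropped the space after the token; now bare tokens are joined with `' '`, keeping spaces on both sides (the token name here is hypothetical):

```python
token = "[NEW_TOK]"

# Old behavior: space before the added token, none after it.
print(''.join(["hello", " " + token, "world"]))  # 'hello [NEW_TOK]world'

# New behavior: pieces joined with spaces on both sides.
print(' '.join(["hello", token, "world"]))       # 'hello [NEW_TOK] world'
```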