From 5a5c4349e8a141d2c0915d71cb3cee101da0db6f Mon Sep 17 00:00:00 2001 From: Pierric Cistac Date: Fri, 13 Dec 2019 10:02:33 -0500 Subject: [PATCH 01/12] Fix summarization `to_cpu` doc --- examples/summarization/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/summarization/README.md b/examples/summarization/README.md index 96825cfa465..b98581e8e5f 100644 --- a/examples/summarization/README.md +++ b/examples/summarization/README.md @@ -29,7 +29,7 @@ And move all the stories to the same folder. We will refer as `$DATA_PATH` the p python run_summarization.py \ --documents_dir $DATA_PATH \ --summaries_output_dir $SUMMARIES_PATH \ # optional - --to_cpu false \ + --no_cuda false \ --batch_size 4 \ --min_length 50 \ --max_length 200 \ @@ -39,7 +39,7 @@ python run_summarization.py \ --compute_rouge true ``` -The scripts executes on GPU if one is available and if `to_cpu` is not set to `true`. Inference on multiple GPUs is not suported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize). +The scripts executes on GPU if one is available and if `no_cuda` is not set to `true`. Inference on multiple GPUs is not suported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize). ## Summarize any text @@ -49,7 +49,7 @@ Put the documents that you would like to summarize in a folder (the path to whic python run_summarization.py \ --documents_dir $DATA_PATH \ --summaries_output_dir $SUMMARIES_PATH \ # optional - --to_cpu false \ + --no_cuda false \ --batch_size 4 \ --min_length 50 \ --max_length 200 \ From c8ed1c82c8a42ef700d4129d227fa356385c1d60 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 13 Dec 2019 12:13:48 -0500 Subject: [PATCH 02/12] [SQUAD] Load checkpoint when evaluating without training --- examples/run_squad.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 117b86e32cd..a39915ee8bc 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -580,10 +580,16 @@ def main(): # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory results = {} if args.do_eval and args.local_rank in [-1, 0]: - checkpoints = [args.output_dir] - if args.eval_all_checkpoints: - checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) - logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + + if args.do_train: + logger.info("Loading checkpoints saved during training for evaluation") + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) + logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + else: + logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path) + checkpoints = [args.model_name_or_path] logger.info("Evaluate the following checkpoints: %s", checkpoints) From f24a228a9315a4b723509bc9144b53d2bcbc4217 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 13 Dec 2019 14:50:35 
-0500 Subject: [PATCH 03/12] Speed up tokenization process --- transformers/data/processors/squad.py | 2 +- transformers/tokenization_utils.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py index 9bc43756842..e193f6153e8 100644 --- a/transformers/data/processors/squad.py +++ b/transformers/data/processors/squad.py @@ -116,7 +116,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, unique_id = 1000000000 features = [] - for (example_index, example) in enumerate(tqdm(examples)): + for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")): if is_training and not example.is_impossible: # Get start and end position start_position = example.start_position diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index 317ecd167b7..e87c87787b9 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -637,9 +637,11 @@ class PreTrainedTokenizer(object): text: The sequence to be encoded. **kwargs: passed to the child `self.tokenize()` method """ + all_special_tokens = self.all_special_tokens + def lowercase_text(t): # convert non-special tokens to lowercase - escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens] + escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens] pattern = r'(^' + r'|'.join(escaped_special_toks) + r')|' + \ r'(.+?)' return re.sub( @@ -680,17 +682,17 @@ class PreTrainedTokenizer(object): tokenized_text = [] for sub_text in text_list: if sub_text not in self.added_tokens_encoder \ - and sub_text not in self.all_special_tokens: + and sub_text not in all_special_tokens: tokenized_text += split_on_token(tok, sub_text) else: tokenized_text += [sub_text] text_list = tokenized_text return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \ - in self.added_tokens_encoder and token not in self.all_special_tokens \ + in self.added_tokens_encoder and token not in all_special_tokens \ else [token] for token in tokenized_text))) - added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens + added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens tokenized_text = split_on_tokens(added_tokens, text) return tokenized_text From d46147294852694d1dc701c72b9053ff2e726265 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 13 Dec 2019 15:31:52 -0500 Subject: [PATCH 04/12] return for SQuAD [BLACKED] --- transformers/data/processors/glue.py | 2 +- transformers/data/processors/squad.py | 280 ++++++++++++++++---------- 2 files changed, 172 insertions(+), 110 deletions(-) diff --git a/transformers/data/processors/glue.py b/transformers/data/processors/glue.py index 518251b0503..11ebd949def 100644 --- a/transformers/data/processors/glue.py +++ b/transformers/data/processors/glue.py @@ -133,7 +133,7 @@ def glue_convert_examples_to_features(examples, tokenizer, if is_tf_available() and is_tf_dataset: def gen(): for ex in features: - yield ({'input_ids': ex.input_ids, + yield ({'input_ids': ex.input_ids, 'attention_mask': ex.attention_mask, 'token_type_ids': ex.token_type_ids}, ex.label) diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py index e193f6153e8..84aa429e26a 100644 --- a/transformers/data/processors/squad.py +++ b/transformers/data/processors/squad.py @@ -18,19 +18,20 @@ if is_tf_available(): logger = 
logging.getLogger(__name__) -def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, - orig_answer_text): + +def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): """Returns tokenized answer spans that better match the annotated answer.""" tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) for new_start in range(input_start, input_end + 1): for new_end in range(input_end, new_start - 1, -1): - text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) + text_span = " ".join(doc_tokens[new_start : (new_end + 1)]) if text_span == tok_answer_text: return (new_start, new_end) return (input_start, input_end) + def _check_is_max_context(doc_spans, cur_span_index, position): """Check if this is the 'max context' doc span for the token.""" best_score = None @@ -50,10 +51,11 @@ def _check_is_max_context(doc_spans, cur_span_index, position): return cur_span_index == best_span_index + def _new_check_is_max_context(doc_spans, cur_span_index, position): """Check if this is the 'max context' doc span for the token.""" # if len(doc_spans) == 1: - # return True + # return True best_score = None best_span_index = None for (span_index, doc_span) in enumerate(doc_spans): @@ -71,14 +73,16 @@ def _new_check_is_max_context(doc_spans, cur_span_index, position): return cur_span_index == best_span_index + def _is_whitespace(c): if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: return True return False -def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, - doc_stride, max_query_length, is_training, - return_dataset=False): + +def squad_convert_examples_to_features( + examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False +): """ Converts a list of examples into a list of features that can be directly given as input to a model. It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs. @@ -112,7 +116,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ) """ - # Defining helper methods + # Defining helper methods unique_id = 1000000000 features = [] @@ -123,13 +127,12 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, end_position = example.end_position # If the answer cannot be found in the text, then skip this example. - actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)]) + actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)]) cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) if actual_text.find(cleaned_answer_text) == -1: logger.warning("Could not find answer: '%s' vs. 
'%s'", actual_text, cleaned_answer_text) continue - tok_to_orig_index = [] orig_to_tok_index = [] all_doc_tokens = [] @@ -140,7 +143,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, tok_to_orig_index.append(i) all_doc_tokens.append(sub_token) - if is_training and not example.is_impossible: tok_start_position = orig_to_tok_index[example.start_position] if example.end_position < len(example.doc_tokens) - 1: @@ -153,36 +155,41 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, ) spans = [] - - truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length) - sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence - sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair + + truncated_query = tokenizer.encode( + example.question_text, add_special_tokens=False, max_length=max_query_length + ) + sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence + sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair span_doc_tokens = all_doc_tokens while len(spans) * doc_stride < len(all_doc_tokens): - + encoded_dict = tokenizer.encode_plus( - truncated_query if tokenizer.padding_side == "right" else span_doc_tokens, - span_doc_tokens if tokenizer.padding_side == "right" else truncated_query, - max_length=max_seq_length, - return_overflowing_tokens=True, + truncated_query if tokenizer.padding_side == "right" else span_doc_tokens, + span_doc_tokens if tokenizer.padding_side == "right" else truncated_query, + max_length=max_seq_length, + return_overflowing_tokens=True, pad_to_max_length=True, stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens, - truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first' + truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first", ) - paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens) + paragraph_len = min( + len(all_doc_tokens) - len(spans) * doc_stride, + max_seq_length - len(truncated_query) - sequence_pair_added_tokens, + ) - if tokenizer.pad_token_id in encoded_dict['input_ids']: - non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)] + if tokenizer.pad_token_id in encoded_dict["input_ids"]: + non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)] else: - non_padded_ids = encoded_dict['input_ids'] + non_padded_ids = encoded_dict["input_ids"] tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) token_to_orig_map = {} for i in range(paragraph_len): - index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i + index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i] encoded_dict["paragraph_len"] = paragraph_len @@ -202,16 +209,20 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, for doc_span_index in range(len(spans)): for j in range(spans[doc_span_index]["paragraph_len"]): is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) - index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j + index = ( + j + 
if tokenizer.padding_side == "left" + else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j + ) spans[doc_span_index]["token_is_max_context"][index] = is_max_context for span in spans: # Identify the position of the CLS token - cls_index = span['input_ids'].index(tokenizer.cls_token_id) + cls_index = span["input_ids"].index(tokenizer.cls_token_id) # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) # Original TF implem also keep the classification token (set to 0) (not sure why...) - p_mask = np.array(span['token_type_ids']) + p_mask = np.array(span["token_type_ids"]) p_mask = np.minimum(p_mask, 1) @@ -224,7 +235,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, # Set the CLS index to '0' p_mask[cls_index] = 0 - span_is_impossible = example.is_impossible start_position = 0 end_position = 0 @@ -247,55 +257,99 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, doc_offset = 0 else: doc_offset = len(truncated_query) + sequence_added_tokens - + start_position = tok_start_position - doc_start + doc_offset end_position = tok_end_position - doc_start + doc_offset - - features.append(SquadFeatures( - span['input_ids'], - span['attention_mask'], - span['token_type_ids'], - cls_index, - p_mask.tolist(), - - example_index=example_index, - unique_id=unique_id, - paragraph_len=span['paragraph_len'], - token_is_max_context=span["token_is_max_context"], - tokens=span["tokens"], - token_to_orig_map=span["token_to_orig_map"], - - start_position=start_position, - end_position=end_position - )) + features.append( + SquadFeatures( + span["input_ids"], + span["attention_mask"], + span["token_type_ids"], + cls_index, + p_mask.tolist(), + example_index=example_index, + unique_id=unique_id, + paragraph_len=span["paragraph_len"], + token_is_max_context=span["token_is_max_context"], + tokens=span["tokens"], + token_to_orig_map=span["token_to_orig_map"], + start_position=start_position, + end_position=end_position, + ) + ) unique_id += 1 - if return_dataset == 'pt': + if return_dataset == "pt": if not is_torch_available(): raise ImportError("Pytorch must be installed to return a pytorch dataset.") # Convert to Tensors and build dataset all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) - all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) - all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) + all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long) + all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) if not is_training: all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) - dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, - all_example_index, all_cls_index, all_p_mask) + dataset = TensorDataset( + all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask + ) else: all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) - dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, - all_start_positions, all_end_positions, - all_cls_index, all_p_mask) + 
dataset = TensorDataset( + all_input_ids, + all_attention_masks, + all_token_type_ids, + all_start_positions, + all_end_positions, + all_cls_index, + all_p_mask, + ) return features, dataset - + elif return_dataset == "tf": + if not is_tf_available(): + raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.") + + def gen(): + for ex in features: + yield ( + { + "input_ids": ex.input_ids, + "attention_mask": ex.attention_mask, + "token_type_ids": ex.token_type_ids, + }, { + "start_position": ex.start_position, + "end_position": ex.end_position, + "cls_index": ex.cls_index, + "p_mask": ex.p_mask, + } + ) + + return tf.data.Dataset.from_generator( + gen, + ( + {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, + {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32}, + ), + ( + { + "input_ids": tf.TensorShape([None]), + "attention_mask": tf.TensorShape([None]), + "token_type_ids": tf.TensorShape([None]), + }, + { + "start_position": tf.TensorShape([]), + "end_position": tf.TensorShape([]), + "cls_index": tf.TensorShape([]), + "p_mask": tf.TensorShape([None]), + }, + ), + ) return features @@ -305,31 +359,32 @@ class SquadProcessor(DataProcessor): Processor for the SQuAD data set. Overriden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and version 2.0 of SQuAD, respectively. """ + train_file = None dev_file = None def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False): if not evaluate: - answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8') - answer_start = tensor_dict['answers']['answer_start'][0].numpy() + answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8") + answer_start = tensor_dict["answers"]["answer_start"][0].numpy() answers = [] else: - answers = [{ - "answer_start": start.numpy(), - "text": text.numpy().decode('utf-8') - } for start, text in zip(tensor_dict['answers']["answer_start"], tensor_dict['answers']["text"])] + answers = [ + {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")} + for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"]) + ] answer = None answer_start = None return SquadExample( - qas_id=tensor_dict['id'].numpy().decode("utf-8"), - question_text=tensor_dict['question'].numpy().decode('utf-8'), - context_text=tensor_dict['context'].numpy().decode('utf-8'), + qas_id=tensor_dict["id"].numpy().decode("utf-8"), + question_text=tensor_dict["question"].numpy().decode("utf-8"), + context_text=tensor_dict["context"].numpy().decode("utf-8"), answer_text=answer, start_position_character=answer_start, - title=tensor_dict['title'].numpy().decode('utf-8'), - answers=answers + title=tensor_dict["title"].numpy().decode("utf-8"), + answers=answers, ) def get_examples_from_dataset(self, dataset, evaluate=False): @@ -359,7 +414,7 @@ class SquadProcessor(DataProcessor): examples = [] for tensor_dict in tqdm(dataset): - examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate)) + examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate)) return examples @@ -379,7 +434,9 @@ class SquadProcessor(DataProcessor): if self.train_file is None: raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") - with open(os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding='utf-8') as reader: + with open( + os.path.join(data_dir, self.train_file if filename is 
None else filename), "r", encoding="utf-8" + ) as reader: input_data = json.load(reader)["data"] return self._create_examples(input_data, "train") @@ -397,8 +454,10 @@ class SquadProcessor(DataProcessor): if self.dev_file is None: raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") - - with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader: + + with open( + os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8" + ) as reader: input_data = json.load(reader)["data"] return self._create_examples(input_data, "dev") @@ -406,7 +465,7 @@ class SquadProcessor(DataProcessor): is_training = set_type == "train" examples = [] for entry in tqdm(input_data): - title = entry['title'] + title = entry["title"] for paragraph in entry["paragraphs"]: context_text = paragraph["context"] for qa in paragraph["qas"]: @@ -415,7 +474,7 @@ class SquadProcessor(DataProcessor): start_position_character = None answer_text = None answers = [] - + if "is_impossible" in qa: is_impossible = qa["is_impossible"] else: @@ -424,8 +483,8 @@ class SquadProcessor(DataProcessor): if not is_impossible: if is_training: answer = qa["answers"][0] - answer_text = answer['text'] - start_position_character = answer['answer_start'] + answer_text = answer["text"] + start_position_character = answer["answer_start"] else: answers = qa["answers"] @@ -437,12 +496,13 @@ class SquadProcessor(DataProcessor): start_position_character=start_position_character, title=title, is_impossible=is_impossible, - answers=answers + answers=answers, ) examples.append(example) return examples + class SquadV1Processor(SquadProcessor): train_file = "train-v1.1.json" dev_file = "dev-v1.1.json" @@ -451,7 +511,7 @@ class SquadV1Processor(SquadProcessor): class SquadV2Processor(SquadProcessor): train_file = "train-v2.0.json" dev_file = "dev-v2.0.json" - + class SquadExample(object): """ @@ -468,21 +528,23 @@ class SquadExample(object): is_impossible: False by default, set to True if the example has no possible answer. 
""" - def __init__(self, - qas_id, - question_text, - context_text, - answer_text, - start_position_character, - title, - answers=[], - is_impossible=False): + def __init__( + self, + qas_id, + question_text, + context_text, + answer_text, + start_position_character, + title, + answers=[], + is_impossible=False, + ): self.qas_id = qas_id self.question_text = question_text self.context_text = context_text self.answer_text = answer_text self.title = title - self.is_impossible = is_impossible + self.is_impossible = is_impossible self.answers = answers self.start_position, self.end_position = 0, 0 @@ -537,24 +599,23 @@ class SquadFeatures(object): end_position: end of the answer token index """ - def __init__(self, - input_ids, - attention_mask, - token_type_ids, - cls_index, - p_mask, - - example_index, - unique_id, - paragraph_len, - token_is_max_context, - tokens, - token_to_orig_map, - - start_position, - end_position - ): - self.input_ids = input_ids + def __init__( + self, + input_ids, + attention_mask, + token_type_ids, + cls_index, + p_mask, + example_index, + unique_id, + paragraph_len, + token_is_max_context, + tokens, + token_to_orig_map, + start_position, + end_position, + ): + self.input_ids = input_ids self.attention_mask = attention_mask self.token_type_ids = token_type_ids self.cls_index = cls_index @@ -580,12 +641,13 @@ class SquadResult(object): start_logits: The logits corresponding to the start of the answer end_logits: The logits corresponding to the end of the answer """ + def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None): self.start_logits = start_logits self.end_logits = end_logits self.unique_id = unique_id - + if start_top_index: self.start_top_index = start_top_index self.end_top_index = end_top_index - self.cls_logits = cls_logits \ No newline at end of file + self.cls_logits = cls_logits From 866d73ca26a13d7e378b2f88f365cb0807c47805 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 13 Dec 2019 16:09:23 -0500 Subject: [PATCH 05/12] [cli] Upload is now compatible with folders --- transformers/commands/user.py | 57 ++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/transformers/commands/user.py b/transformers/commands/user.py index d79922ed8ac..8e0e5634223 100644 --- a/transformers/commands/user.py +++ b/transformers/commands/user.py @@ -19,8 +19,8 @@ class UserCommands(BaseTransformersCLICommand): list_parser.set_defaults(func=lambda args: ListObjsCommand(args)) # upload upload_parser = parser.add_parser('upload') - upload_parser.add_argument('file', type=str, help='Local filepath of the file to upload.') - upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override object filename on S3.') + upload_parser.add_argument('path', type=str, help='Local path of the folder or individual file to upload.') + upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override individual object filename on S3.') upload_parser.set_defaults(func=lambda args: UploadCommand(args)) @@ -138,28 +138,57 @@ class ListObjsCommand(BaseUserCommand): class UploadCommand(BaseUserCommand): + def walk_dir(self, rel_path): + """ + Recursively list all files in a folder. 
+ """ + entries: List[os.DirEntry] = list(os.scandir(rel_path)) + files = [ + ( + os.path.join(os.getcwd(), f.path), # filepath + f.path # filename + ) + for f in entries if f.is_file() + ] + for f in entries: + if f.is_dir(): + files += self.walk_dir(f.path) + return files + def run(self): token = HfFolder.get_token() if token is None: print("Not logged in") exit(1) - filepath = os.path.join(os.getcwd(), self.args.file) - filename = self.args.filename if self.args.filename is not None else os.path.basename(filepath) - print( - "About to upload file {} to S3 under filename {}".format( - ANSI.bold(filepath), ANSI.bold(filename) + local_path = os.path.abspath(self.args.path) + if os.path.isdir(local_path): + if self.args.filename is not None: + raise ValueError("Cannot specify a filename override when uploading a folder.") + rel_path = os.path.basename(local_path) + files = self.walk_dir(rel_path) + elif os.path.isfile(local_path): + filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path) + files = [(local_path, filename)] + else: + raise ValueError("Not a valid file or directory: {}".format(local_path)) + + for filepath, filename in files: + print( + "About to upload file {} to S3 under filename {}".format( + ANSI.bold(filepath), ANSI.bold(filename) + ) ) - ) choice = input("Proceed? [Y/n] ").lower() if not(choice == "" or choice == "y" or choice == "yes"): print("Abort") exit() print( - ANSI.bold("Uploading... This might take a while if file is large") + ANSI.bold("Uploading... This might take a while if files are large") ) - access_url = self._api.presign_and_upload( - token=token, filename=filename, filepath=filepath - ) - print("Your file now lives at:") - print(access_url) + for filepath, filename in files: + access_url = self._api.presign_and_upload( + token=token, filename=filename, filepath=filepath + ) + print("Your file now lives at:") + print(access_url) From 5b7b78e088352a3aaf1f80d26bb1cd466bc2ac64 Mon Sep 17 00:00:00 2001 From: Pascal Voitot Date: Sun, 8 Dec 2019 23:22:02 +0100 Subject: [PATCH 06/12] :bug: #2096 in tokenizer.decode, adds a space after special tokens to return right formatted string --- transformers/tokenization_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index e87c87787b9..42519c26ba5 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -1180,7 +1180,7 @@ class PreTrainedTokenizer(object): if current_sub_text: sub_texts.append(self.convert_tokens_to_string(current_sub_text)) current_sub_text = [] - sub_texts.append(" " + token) + sub_texts.append(" " + token + " ") else: current_sub_text.append(token) if current_sub_text: From df160af736cba1d50c09abcf92c8fc6c00bcdb13 Mon Sep 17 00:00:00 2001 From: Pascal Voitot Date: Tue, 10 Dec 2019 00:03:38 +0100 Subject: [PATCH 07/12] :bug: #2096 in tokenizer.decode, space is not joined between all subtexts instead of before added tokens --- transformers/tests/tokenization_bert_test.py | 16 ++++++++++++++++ transformers/tokenization_utils.py | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/transformers/tests/tokenization_bert_test.py b/transformers/tests/tokenization_bert_test.py index f3902489565..c47f149e9ae 100644 --- a/transformers/tests/tokenization_bert_test.py +++ b/transformers/tests/tokenization_bert_test.py @@ -99,6 +99,21 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 
self.assertListEqual( tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) + def test_encode_decode_with_spaces(self): + tokenizer = self.get_tokenizer() + + new_toks = ['[ABC]', '[DEF]', 'GHI IHG'] + tokenizer.add_tokens(new_toks) + input = "unwanted running [ABC] [DEF] running unwanted [ABC] GHI IHG unwanted [DEF]" + encoded = tokenizer.encode(input) + decoded = tokenizer.decode(encoded) + self.assertEqual( + decoded.lower(), + (f"[CLS] {input.lower()} [SEP]").lower() + ) + + + def test_is_whitespace(self): self.assertTrue(_is_whitespace(u" ")) self.assertTrue(_is_whitespace(u"\t")) @@ -139,5 +154,6 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): assert encoded_sentence == [101] + text + [102] assert encoded_pair == [101] + text + [102] + text_2 + [102] + if __name__ == '__main__': unittest.main() diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index 42519c26ba5..8aef80fec89 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -1180,12 +1180,12 @@ class PreTrainedTokenizer(object): if current_sub_text: sub_texts.append(self.convert_tokens_to_string(current_sub_text)) current_sub_text = [] - sub_texts.append(" " + token + " ") + sub_texts.append(token) else: current_sub_text.append(token) if current_sub_text: sub_texts.append(self.convert_tokens_to_string(current_sub_text)) - text = ''.join(sub_texts) + text = ' '.join(sub_texts) if clean_up_tokenization_spaces: clean_text = self.clean_up_tokenization(text) From dd2add9f6efdaa248f3074b865dc67c439b30a4d Mon Sep 17 00:00:00 2001 From: Pascal Voitot Date: Tue, 10 Dec 2019 00:29:44 +0100 Subject: [PATCH 08/12] more tests --- transformers/tests/tokenization_bert_test.py | 2 +- transformers/tests/tokenization_gpt2_test.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/transformers/tests/tokenization_bert_test.py b/transformers/tests/tokenization_bert_test.py index c47f149e9ae..b93934dd676 100644 --- a/transformers/tests/tokenization_bert_test.py +++ b/transformers/tests/tokenization_bert_test.py @@ -109,7 +109,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): decoded = tokenizer.decode(encoded) self.assertEqual( decoded.lower(), - (f"[CLS] {input.lower()} [SEP]").lower() + (f"[CLS] {input} [SEP]").lower() ) diff --git a/transformers/tests/tokenization_gpt2_test.py b/transformers/tests/tokenization_gpt2_test.py index a77cc75ec2e..9e6ca3c4fda 100644 --- a/transformers/tests/tokenization_gpt2_test.py +++ b/transformers/tests/tokenization_gpt2_test.py @@ -67,6 +67,20 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + def test_encode_decode_with_spaces(self): + tokenizer = self.get_tokenizer() + + new_toks = ['[ABC]', '[DEF]', 'GHI IHG'] + tokenizer.add_tokens(new_toks) + input = "lower newer [ABC] [DEF] newer lower [ABC] GHI IHG newer lower[DEF]" + encoded = tokenizer.encode(input) + decoded = tokenizer.decode(encoded) + self.assertEqual( + decoded.lower(), + input.lower() + ) + + if __name__ == '__main__': unittest.main() From 4cbdc7d910a0a12871a8e29760a3a6721a138421 Mon Sep 17 00:00:00 2001 From: Pascal Voitot Date: Tue, 10 Dec 2019 09:37:15 +0100 Subject: [PATCH 09/12] missed space --- transformers/tests/tokenization_bert_test.py | 2 -- transformers/tests/tokenization_gpt2_test.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git 
a/transformers/tests/tokenization_bert_test.py b/transformers/tests/tokenization_bert_test.py index b93934dd676..a039a24dd8d 100644 --- a/transformers/tests/tokenization_bert_test.py +++ b/transformers/tests/tokenization_bert_test.py @@ -112,8 +112,6 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): (f"[CLS] {input} [SEP]").lower() ) - - def test_is_whitespace(self): self.assertTrue(_is_whitespace(u" ")) self.assertTrue(_is_whitespace(u"\t")) diff --git a/transformers/tests/tokenization_gpt2_test.py b/transformers/tests/tokenization_gpt2_test.py index 9e6ca3c4fda..1b4fe428746 100644 --- a/transformers/tests/tokenization_gpt2_test.py +++ b/transformers/tests/tokenization_gpt2_test.py @@ -72,7 +72,7 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): new_toks = ['[ABC]', '[DEF]', 'GHI IHG'] tokenizer.add_tokens(new_toks) - input = "lower newer [ABC] [DEF] newer lower [ABC] GHI IHG newer lower[DEF]" + input = "lower newer [ABC] [DEF] newer lower [ABC] GHI IHG newer lower [DEF]" encoded = tokenizer.encode(input) decoded = tokenizer.decode(encoded) self.assertEqual( From f2ac50cb5560e13d941f1ea3dec3399f12f7a3fb Mon Sep 17 00:00:00 2001 From: Pascal Voitot Date: Tue, 10 Dec 2019 09:58:06 +0100 Subject: [PATCH 10/12] better for python2.x --- transformers/tests/tokenization_bert_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers/tests/tokenization_bert_test.py b/transformers/tests/tokenization_bert_test.py index a039a24dd8d..77b124cdf2c 100644 --- a/transformers/tests/tokenization_bert_test.py +++ b/transformers/tests/tokenization_bert_test.py @@ -109,7 +109,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): decoded = tokenizer.decode(encoded) self.assertEqual( decoded.lower(), - (f"[CLS] {input} [SEP]").lower() + ("[CLS] " + input + " [SEP]").lower() ) def test_is_whitespace(self): From c3248cf122014dce10c0c8d0e663a95c948493e3 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 11 Dec 2019 12:36:37 -0500 Subject: [PATCH 11/12] Tests for all tokenizers --- transformers/tests/tokenization_bert_test.py | 13 ------------- transformers/tests/tokenization_gpt2_test.py | 15 --------------- transformers/tests/tokenization_tests_commons.py | 9 +++++++++ 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/transformers/tests/tokenization_bert_test.py b/transformers/tests/tokenization_bert_test.py index 77b124cdf2c..c503ea5e1e4 100644 --- a/transformers/tests/tokenization_bert_test.py +++ b/transformers/tests/tokenization_bert_test.py @@ -99,19 +99,6 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): self.assertListEqual( tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) - def test_encode_decode_with_spaces(self): - tokenizer = self.get_tokenizer() - - new_toks = ['[ABC]', '[DEF]', 'GHI IHG'] - tokenizer.add_tokens(new_toks) - input = "unwanted running [ABC] [DEF] running unwanted [ABC] GHI IHG unwanted [DEF]" - encoded = tokenizer.encode(input) - decoded = tokenizer.decode(encoded) - self.assertEqual( - decoded.lower(), - ("[CLS] " + input + " [SEP]").lower() - ) - def test_is_whitespace(self): self.assertTrue(_is_whitespace(u" ")) self.assertTrue(_is_whitespace(u"\t")) diff --git a/transformers/tests/tokenization_gpt2_test.py b/transformers/tests/tokenization_gpt2_test.py index 1b4fe428746..5eae767bdfc 100644 --- a/transformers/tests/tokenization_gpt2_test.py +++ b/transformers/tests/tokenization_gpt2_test.py @@ -67,20 +67,5 @@ class 
GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) - def test_encode_decode_with_spaces(self): - tokenizer = self.get_tokenizer() - - new_toks = ['[ABC]', '[DEF]', 'GHI IHG'] - tokenizer.add_tokens(new_toks) - input = "lower newer [ABC] [DEF] newer lower [ABC] GHI IHG newer lower [DEF]" - encoded = tokenizer.encode(input) - decoded = tokenizer.decode(encoded) - self.assertEqual( - decoded.lower(), - input.lower() - ) - - - if __name__ == '__main__': unittest.main() diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py index c0099581357..13e7ae746af 100644 --- a/transformers/tests/tokenization_tests_commons.py +++ b/transformers/tests/tokenization_tests_commons.py @@ -232,6 +232,15 @@ class CommonTestCases: self.assertNotEqual(len(tokens_2), 0) self.assertIsInstance(text_2, (str, unicode)) + def test_encode_decode_with_spaces(self): + tokenizer = self.get_tokenizer() + + new_toks = ['[ABC]', '[DEF]', 'GHI IHG'] + tokenizer.add_tokens(new_toks) + input = "[ABC] [DEF] [ABC] GHI IHG [DEF]" + encoded = tokenizer.encode(input, add_special_tokens=False) + decoded = tokenizer.decode(encoded) + self.assertEqual(decoded, input) def test_pretrained_model_lists(self): weights_list = list(self.tokenizer_class.max_model_input_sizes.keys()) From 7bd11dda6f43656cf0a3891b7f61a67196d233b4 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 13 Dec 2019 16:45:30 -0500 Subject: [PATCH 12/12] Release: v2.2.2 --- README.md | 2 +- docs/source/conf.py | 2 +- setup.py | 2 +- transformers/__init__.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index f3aa8a95ee2..f24ceaa6d23 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Choose the right framework for every part of a model's lifetime | [Quick tour: Fine-tuning/usage scripts](#quick-tour-of-the-fine-tuningusage-scripts) | Using provided scripts: GLUE, SQuAD and Text generation | | [Migrating from pytorch-transformers to transformers](#Migrating-from-pytorch-transformers-to-transformers) | Migrating your code from pytorch-transformers to transformers | | [Migrating from pytorch-pretrained-bert to pytorch-transformers](#Migrating-from-pytorch-pretrained-bert-to-transformers) | Migrating your code from pytorch-pretrained-bert to transformers | -| [Documentation][(v2.2.0/v2.2.1)](https://huggingface.co/transformers/v2.2.0) [(v2.1.1)](https://huggingface.co/transformers/v2.1.1) [(v2.0.0)](https://huggingface.co/transformers/v2.0.0) [(v1.2.0)](https://huggingface.co/transformers/v1.2.0) [(v1.1.0)](https://huggingface.co/transformers/v1.1.0) [(v1.0.0)](https://huggingface.co/transformers/v1.0.0) [(master)](https://huggingface.co/transformers) | Full API documentation and more | +| [Documentation][(v2.2.0/v2.2.1/v2.2.2)](https://huggingface.co/transformers/v2.2.0) [(v2.1.1)](https://huggingface.co/transformers/v2.1.1) [(v2.0.0)](https://huggingface.co/transformers/v2.0.0) [(v1.2.0)](https://huggingface.co/transformers/v1.2.0) [(v1.1.0)](https://huggingface.co/transformers/v1.1.0) [(v1.0.0)](https://huggingface.co/transformers/v1.0.0) [(master)](https://huggingface.co/transformers) | Full API documentation and more | ## Installation diff --git a/docs/source/conf.py b/docs/source/conf.py index 2f8505ab3a7..99b7b449220 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ author = u'huggingface' # The short X.Y version version = u'' # The full 
version, including alpha/beta/rc tags -release = u'2.2.1' +release = u'2.2.2' # -- General configuration --------------------------------------------------- diff --git a/setup.py b/setup.py index c4af32df83a..eacb5ecec0d 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ extras['all'] = [package for package in extras.values()] setup( name="transformers", - version="2.2.1", + version="2.2.2", author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors", author_email="thomas@huggingface.co", description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch", diff --git a/transformers/__init__.py b/transformers/__init__.py index 5d7b0b772cb..c11919f0a75 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.1" +__version__ = "2.2.2" # Work around to update TensorFlow's absl.logging threshold which alters the # default Python logging output behavior when present.
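
For context, a minimal usage sketch of the reworked `squad_convert_examples_to_features` API landed in PATCH 04. It is illustrative only and not part of the patch series; the data directory, tokenizer checkpoint, and the 384/128/64 hyper-parameters are assumptions rather than values taken from the patches.

```python
# Illustrative sketch, not part of the patch series above. Assumes transformers ~v2.2.2
# with these patches applied, a SQuAD v2.0 dev set under the (hypothetical) data/squad/
# directory, and commonly used, but here assumed, hyper-parameters (384 / 128 / 64).
from transformers import BertTokenizer
from transformers.data.processors.squad import (
    SquadV2Processor,
    squad_convert_examples_to_features,
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

processor = SquadV2Processor()                       # looks for dev-v2.0.json in data_dir
examples = processor.get_dev_examples("data/squad")  # hypothetical data directory

# return_dataset="pt" makes the function return the features together with a
# torch TensorDataset, matching the PyTorch branch shown in PATCH 04.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",
)
```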