diff --git a/examples/run_transfo_xl.py b/examples/run_transfo_xl.py
index 8139f28baf5..0ea7b320536 100644
--- a/examples/run_transfo_xl.py
+++ b/examples/run_transfo_xl.py
@@ -28,7 +28,7 @@ import math
 
 import torch
 
-from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
+from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
@@ -80,6 +80,7 @@ def main():
     # The pre-processing involve computing word frequencies to prepare the Adaptive input and SoftMax
     # and tokenizing the dataset
     # The pre-processed corpus is a convertion (using the conversion script )
+    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
     corpus = TransfoXLCorpus.from_pretrained(args.model_name)
     ntokens = len(corpus.vocab)
 
diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py
index bbb3e25fc79..6e2e11ed921 100644
--- a/pytorch_pretrained_bert/tokenization.py
+++ b/pytorch_pretrained_bert/tokenization.py
@@ -134,6 +134,19 @@ class BertTokenizer(object):
             tokens.append(self.ids_to_tokens[i])
         return tokens
 
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        index = 0
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
+                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
+                    index = token_index
+                writer.write(token + u'\n')
+                index += 1
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py
index db95719dbcc..07db995b968 100644
--- a/pytorch_pretrained_bert/tokenization_gpt2.py
+++ b/pytorch_pretrained_bert/tokenization_gpt2.py
@@ -187,6 +187,23 @@ class GPT2Tokenizer(object):
         self.cache[token] = word
         return word
 
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            json.dump(self.encoder, f)
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
+
     def encode(self, text):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py
index 240122d12df..aa0438ccf80 100644
--- a/pytorch_pretrained_bert/tokenization_openai.py
+++ b/pytorch_pretrained_bert/tokenization_openai.py
@@ -261,3 +261,20 @@ class OpenAIGPTTokenizer(object):
                 ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
                 ).replace(" 've", "'ve")
         return out_string
+
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            json.dump(self.encoder, f)
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py
index b5360c51843..b6470c76670 100644
--- a/pytorch_pretrained_bert/tokenization_transfo_xl.py
+++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py
@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            if os.path.isdir(pretrained_model_name_or_path):
+                vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            else:
+                vocab_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@@ -141,6 +144,11 @@ class TransfoXLTokenizer(object):
         else:
             raise ValueError('No token in vocabulary')
 
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        torch.save(self.__dict__, vocab_file)
+
     def build_vocab(self):
         if self.vocab_file:
             print('building vocab from {}'.format(self.vocab_file))
@@ -245,82 +253,24 @@ class TransfoXLTokenizer(object):
     def __len__(self):
         return len(self.idx2sym)
 
-    def _run_split_on_punc(self, text):
-        """Splits punctuation on a piece of text."""
-        if text in self.never_split:
-            return [text]
-        chars = list(text)
-        i = 0
-        start_new_word = True
-        output = []
-        while i < len(chars):
-            char = chars[i]
-            if _is_punctuation(char):
-                output.append([char])
-                start_new_word = True
-            else:
-                if start_new_word:
-                    output.append([])
-                start_new_word = False
-                output[-1].append(char)
-            i += 1
-
-        return ["".join(x) for x in output]
-
-    def _run_strip_accents(self, text):
-        """Strips accents from a piece of text."""
-        text = unicodedata.normalize("NFD", text)
-        output = []
-        for char in text:
-            cat = unicodedata.category(char)
-            if cat == "Mn":
-                continue
-            output.append(char)
-        return "".join(output)
-
-    def _clean_text(self, text):
-        """Performs invalid character removal and whitespace cleanup on text."""
-        output = []
-        for char in text:
-            cp = ord(char)
-            if cp == 0 or cp == 0xfffd or _is_control(char):
-                continue
-            if _is_whitespace(char):
-                output.append(" ")
-            else:
-                output.append(char)
-        return "".join(output)
-
-    def whitespace_tokenize(self, text):
-        """Runs basic whitespace cleaning and splitting on a piece of text."""
-        text = text.strip()
-        if not text:
-            return []
-        if self.delimiter == '':
-            tokens = text
-        else:
-            tokens = text.split(self.delimiter)
-        return tokens
-
     def tokenize(self, line, add_eos=False, add_double_eos=False):
-        line = self._clean_text(line)
         line = line.strip()
+        # convert to lower case
+        if self.lower_case:
+            line = line.lower()
 
-        symbols = self.whitespace_tokenize(line)
-
-        split_symbols = []
-        for symbol in symbols:
-            if self.lower_case and symbol not in self.never_split:
-                symbol = symbol.lower()
-                symbol = self._run_strip_accents(symbol)
-            split_symbols.extend(self._run_split_on_punc(symbol))
+        # empty delimiter '' will evaluate False
+        if self.delimiter == '':
+            symbols = line
+        else:
+            symbols = line.split(self.delimiter)
 
         if add_double_eos: # lm1b
-            return ['<S>'] + split_symbols + ['<S>']
+            return ['<S>'] + symbols + ['<S>']
         elif add_eos:
-            return split_symbols + ['<eos>']
+            return symbols + ['<eos>']
         else:
-            return split_symbols
+            return symbols
 
 
 class LMOrderedIterator(object):
@@ -631,42 +581,3 @@ def get_lm_corpus(datadir, dataset):
         torch.save(corpus, fn)
 
     return corpus
-
-def _is_whitespace(char):
-    """Checks whether `chars` is a whitespace character."""
-    # \t, \n, and \r are technically contorl characters but we treat them
-    # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
-        return True
-    cat = unicodedata.category(char)
-    if cat == "Zs":
-        return True
-    return False
-
-
-def _is_control(char):
-    """Checks whether `chars` is a control character."""
-    # These are technically control characters but we count them as whitespace
-    # characters.
-    if char == "\t" or char == "\n" or char == "\r":
-        return False
-    cat = unicodedata.category(char)
-    if cat.startswith("C"):
-        return True
-    return False
-
-
-def _is_punctuation(char):
-    """Checks whether `chars` is a punctuation character."""
-    cp = ord(char)
-    # We treat all non-letter/number ASCII as punctuation.
-    # Characters such as "^", "$", and "`" are not in the Unicode
-    # Punctuation class but we treat them as punctuation anyways, for
-    # consistency.
-    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
-            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
-        return True
-    cat = unicodedata.category(char)
-    if cat.startswith("P"):
-        return True
-    return False
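
A minimal save-and-reload sketch of the save_vocabulary methods added above, paired with the existing from_pretrained loaders. The directory name and the bert-base-uncased shortcut are illustrative, and the sketch assumes from_pretrained accepts a local directory, which BertTokenizer already supports and which the TransfoXLTokenizer change above enables:

import os

from pytorch_pretrained_bert import BertTokenizer

save_dir = './saved_tokenizer'  # illustrative path, not part of the patch
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# Write the vocabulary file into save_dir with the new save_vocabulary() ...
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.save_vocabulary(save_dir)

# ... then load the tokenizer back from that directory and check the round trip.
reloaded = BertTokenizer.from_pretrained(save_dir)
assert reloaded.tokenize("Hello, world!") == tokenizer.tokenize("Hello, world!")

GPT2Tokenizer and OpenAIGPTTokenizer follow the same pattern, writing their JSON vocabulary and BPE merges files, while TransfoXLTokenizer serializes its full state with torch.save.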