[Json configs] Make json prettier for all saved tokenizer files & ensure same json format for all processors (tok + feat_extract) (#17457)

* [Json dump] Make json prettier

* correct more tokenizers

* more patterns

* add aggressive test

* the aggressive test was actually useful :-)

* more tests

* Apply suggestions from code review
Patrick von Platen 2022-05-31 17:07:30 +02:00 committed by GitHub
parent 6ee1474b67
commit f394a2a50d
22 changed files with 52 additions and 24 deletions
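
The change itself is the same one-liner in every touched file: json.dumps gains indent=2 and sort_keys=True plus a trailing newline. A minimal sketch of the difference, using a made-up two-entry vocab rather than any of the real files below:

import json

vocab = {"hello": 1, "<unk>": 0}  # hypothetical vocab, for illustration only

# before: single line, insertion order, no trailing newline
print(json.dumps(vocab, ensure_ascii=False))
# {"hello": 1, "<unk>": 0}

# after: one key per line, 2-space indent, sorted keys, trailing newline
print(json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
# {
#   "<unk>": 0,
#   "hello": 1
# }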

@@ -354,6 +354,8 @@ class FeatureExtractionMixin(PushToHubMixin):
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Feature extractor pushed to the hub in this commit: {url}")
+ return [output_feature_extractor_file]
@classmethod
def get_feature_extractor_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
@@ -318,7 +318,7 @@ class BartTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -220,7 +220,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -345,7 +345,7 @@ class CLIPTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -236,7 +236,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -505,11 +505,11 @@ class FSMTTokenizer(PreTrainedTokenizer):
)
with open(src_vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
with open(tgt_vocab_file, "w", encoding="utf-8") as f:
tgt_vocab = {v: k for k, v in self.decoder.items()}
- f.write(json.dumps(tgt_vocab, ensure_ascii=False))
+ f.write(json.dumps(tgt_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merges_file, "w", encoding="utf-8") as writer:

@@ -297,7 +297,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -437,7 +437,7 @@ class LayoutLMv3Tokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -1386,6 +1386,6 @@ class LukeTokenizer(RobertaTokenizer):
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.entity_vocab, ensure_ascii=False))
+ f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return vocab_file, merge_file, entity_vocab_file

@@ -1509,7 +1509,7 @@ class MLukeTokenizer(PreTrainedTokenizer):
)
with open(entity_vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.entity_vocab, ensure_ascii=False))
+ f.write(json.dumps(self.entity_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return out_vocab_file, entity_vocab_file

@@ -215,7 +215,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -324,7 +324,7 @@ class RobertaTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -250,7 +250,7 @@ class Speech2Text2Tokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
if self.bpe_ranks is None:

@@ -503,7 +503,7 @@ class TapexTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -603,7 +603,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)

@@ -922,6 +922,6 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)

@@ -568,7 +568,7 @@ class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
return (vocab_file,)

@@ -965,7 +965,7 @@ class XLMTokenizer(PreTrainedTokenizer):
)
with open(vocab_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(self.encoder, ensure_ascii=False))
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:

@@ -1494,3 +1494,20 @@ def nested_simplify(obj, decimals=3):
return nested_simplify(obj.item(), decimals)
else:
raise Exception(f"Not supported: {type(obj)}")
+ def check_json_file_has_correct_format(file_path):
+     with open(file_path, "r") as f:
+         lines = f.readlines()
+         if len(lines) == 1:
+             # length can only be 1 if dict is empty
+             assert lines[0] == "{}"
+         else:
+             # otherwise make sure json has correct format (at least 3 lines)
+             assert len(lines) >= 3
+             # each key on its own line, indent should be 2, min length is 3
+             assert lines[0].strip() == "{"
+             for line in lines[1:-1]:
+                 left_indent = len(line) - len(line.lstrip())
+                 assert left_indent == 2
+             assert lines[-1].strip() == "}"
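
For reference, a quick sketch of what the new helper accepts, written against a temp file saved in the prettified format (the temp-file handling here is illustrative, not part of the commit):

import json
import tempfile

from transformers.testing_utils import check_json_file_has_correct_format

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    # mimic the new save format: 2-space indent, sorted keys, trailing newline
    f.write(json.dumps({"bos_token": "<s>", "eos_token": "</s>"}, indent=2, sort_keys=True) + "\n")

# passes: first line "{", every middle line indented by 2, last line "}"
check_json_file_has_correct_format(f.name)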

@@ -2118,13 +2118,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
custom_object_save(self, save_directory, config=tokenizer_config)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(tokenizer_config, ensure_ascii=False))
+ out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
# Sanitize AddedTokens in special_tokens_map
write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
- f.write(json.dumps(write_dict, ensure_ascii=False))
+ out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+ f.write(out_str)
logger.info(f"Special tokens file saved in {special_tokens_map_file}")
file_names = (tokenizer_config_file, special_tokens_map_file)

@@ -2168,7 +2170,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
- out_str = json.dumps(added_vocab, ensure_ascii=False)
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str)
logger.info(f"added tokens file saved in {added_tokens_file}")

@@ -585,7 +585,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
added_vocab = self.get_added_vocab()
if added_vocab:
with open(added_tokens_file, "w", encoding="utf-8") as f:
- out_str = json.dumps(added_vocab, ensure_ascii=False)
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
f.write(out_str)
vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)

@@ -25,7 +25,7 @@ from pathlib import Path
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
- from transformers.testing_utils import PASS, USER, get_tests_dir, is_staging_test
+ from transformers.testing_utils import PASS, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
from transformers.utils import is_torch_available, is_vision_available

@@ -107,7 +107,8 @@ class FeatureExtractionSavingTestMixin:
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
- feat_extract_first.save_pretrained(tmpdirname)
+ saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
+ check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
self.assertEqual(feat_extract_second.to_dict(), feat_extract_first.to_dict())

@@ -51,6 +51,7 @@ from transformers import (
from transformers.testing_utils import (
PASS,
USER,
+ check_json_file_has_correct_format,
get_tests_dir,
is_pt_tf_cross_test,
is_staging_test,

@@ -3325,6 +3326,11 @@ class TokenizerTesterMixin:
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2)
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
+ # make sure that all ".json" files are saved in the correct format
+ for file_path in tokenizer_r_files + tokenizer_p_files:
+     if os.path.exists(file_path) and file_path.endswith(".json"):
+         check_json_file_has_correct_format(file_path)
# Checks it save with the same files + the tokenizer.json file for the fast one
self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f)