mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add possibility to maintain full copies of files (#12312)
This commit is contained in:
parent
9490d668d2
commit
57461ac0b4
@ -38,6 +38,7 @@ def postprocess_qa_predictions(
|
||||
null_score_diff_threshold: float = 0.0,
|
||||
output_dir: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
is_world_process_zero: bool = True,
|
||||
):
|
||||
"""
|
||||
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
|
||||
@ -90,6 +91,7 @@ def postprocess_qa_predictions(
|
||||
scores_diff_json = collections.OrderedDict()
|
||||
|
||||
# Logging.
|
||||
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
|
||||
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
||||
|
||||
# Let's loop over all the examples!
|
||||
|
@ -27,6 +27,9 @@ TRANSFORMERS_PATH = "src/transformers"
|
||||
PATH_TO_DOCS = "docs/source"
|
||||
REPO_PATH = "."
|
||||
|
||||
# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with)
|
||||
FULL_COPIES = {"examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py"}
|
||||
|
||||
|
||||
def _should_continue(line, indent):
|
||||
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None
|
||||
@ -192,6 +195,30 @@ def check_copies(overwrite: bool = False):
|
||||
check_model_list_copy(overwrite=overwrite)
|
||||
|
||||
|
||||
def check_full_copies(overwrite: bool = False):
|
||||
diffs = []
|
||||
for target, source in FULL_COPIES.items():
|
||||
with open(source, "r", encoding="utf-8") as f:
|
||||
source_code = f.read()
|
||||
with open(target, "r", encoding="utf-8") as f:
|
||||
target_code = f.read()
|
||||
if source_code != target_code:
|
||||
if overwrite:
|
||||
with open(target, "w", encoding="utf-8") as f:
|
||||
print(f"Replacing the content of {target} by the one of {source}.")
|
||||
f.write(source_code)
|
||||
else:
|
||||
diffs.append(f"- {target}: copy does not match {source}.")
|
||||
|
||||
if not overwrite and len(diffs) > 0:
|
||||
diff = "\n".join(diffs)
|
||||
raise Exception(
|
||||
"Found the following copy inconsistencies:\n"
|
||||
+ diff
|
||||
+ "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
|
||||
)
|
||||
|
||||
|
||||
def get_model_list():
|
||||
"""Extracts the model list from the README."""
|
||||
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
|
||||
@ -324,3 +351,4 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
check_copies(args.fix_and_overwrite)
|
||||
check_full_copies(args.fix_and_overwrite)
|
||||
|
Loading…
Reference in New Issue
Block a user