mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Add possibility to maintain full copies of files (#12312)
This commit is contained in:
parent
9490d668d2
commit
57461ac0b4
@ -38,6 +38,7 @@ def postprocess_qa_predictions(
|
|||||||
null_score_diff_threshold: float = 0.0,
|
null_score_diff_threshold: float = 0.0,
|
||||||
output_dir: Optional[str] = None,
|
output_dir: Optional[str] = None,
|
||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
|
is_world_process_zero: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
|
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
|
||||||
@ -90,6 +91,7 @@ def postprocess_qa_predictions(
|
|||||||
scores_diff_json = collections.OrderedDict()
|
scores_diff_json = collections.OrderedDict()
|
||||||
|
|
||||||
# Logging.
|
# Logging.
|
||||||
|
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
|
||||||
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
||||||
|
|
||||||
# Let's loop over all the examples!
|
# Let's loop over all the examples!
|
||||||
|
@ -27,6 +27,9 @@ TRANSFORMERS_PATH = "src/transformers"
|
|||||||
PATH_TO_DOCS = "docs/source"
|
PATH_TO_DOCS = "docs/source"
|
||||||
REPO_PATH = "."
|
REPO_PATH = "."
|
||||||
|
|
||||||
|
# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with)
|
||||||
|
FULL_COPIES = {"examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py"}
|
||||||
|
|
||||||
|
|
||||||
def _should_continue(line, indent):
|
def _should_continue(line, indent):
|
||||||
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None
|
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None
|
||||||
@ -192,6 +195,30 @@ def check_copies(overwrite: bool = False):
|
|||||||
check_model_list_copy(overwrite=overwrite)
|
check_model_list_copy(overwrite=overwrite)
|
||||||
|
|
||||||
|
|
||||||
|
def check_full_copies(overwrite: bool = False):
|
||||||
|
diffs = []
|
||||||
|
for target, source in FULL_COPIES.items():
|
||||||
|
with open(source, "r", encoding="utf-8") as f:
|
||||||
|
source_code = f.read()
|
||||||
|
with open(target, "r", encoding="utf-8") as f:
|
||||||
|
target_code = f.read()
|
||||||
|
if source_code != target_code:
|
||||||
|
if overwrite:
|
||||||
|
with open(target, "w", encoding="utf-8") as f:
|
||||||
|
print(f"Replacing the content of {target} by the one of {source}.")
|
||||||
|
f.write(source_code)
|
||||||
|
else:
|
||||||
|
diffs.append(f"- {target}: copy does not match {source}.")
|
||||||
|
|
||||||
|
if not overwrite and len(diffs) > 0:
|
||||||
|
diff = "\n".join(diffs)
|
||||||
|
raise Exception(
|
||||||
|
"Found the following copy inconsistencies:\n"
|
||||||
|
+ diff
|
||||||
|
+ "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_model_list():
|
def get_model_list():
|
||||||
"""Extracts the model list from the README."""
|
"""Extracts the model list from the README."""
|
||||||
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
|
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
|
||||||
@ -324,3 +351,4 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
check_copies(args.fix_and_overwrite)
|
check_copies(args.fix_and_overwrite)
|
||||||
|
check_full_copies(args.fix_and_overwrite)
|
||||||
|
Loading…
Reference in New Issue
Block a user