Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 04:40:06 +06:00

* image-classification
* depth-estimation
* zero-shot-image-classification
* image-feature-extraction
* image-segmentation
* mask-generation
* object-detection
* zero-shot-object-detection
* image-to-image
* image-text-to-text
* image-to-text
* text-classification
* text-generation
* text-to-audio
* text2text_generation
* fixup
* token-classification
* document-qa
* video-classification
* audio-classification
* automatic-speech-recognition
* feature-extraction
* fill-mask
* zero-shot-audio-classification
* Add pipeline function typing
* Add code generator and checker for pipeline types
* Add to makefile
* style
* Add to CI
* Style
94 lines · 4.6 KiB · Python
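The commit above adds typed `@overload` stubs for the `pipeline` factory, plus this generator/checker script that keeps them in sync. As a rough sketch (not the actual generated code, whose overloads repeat the full `pipeline` signature), the emitted stubs look like this:

    # Illustrative sketch only: the real generated stubs carry the complete
    # `pipeline` signature; it is abbreviated here to `model` and `**kwargs`.
    from typing import Literal, overload

    from transformers import Pipeline, TextClassificationPipeline, TextGenerationPipeline


    @overload
    def pipeline(task: Literal["text-classification"], model=None, **kwargs) -> TextClassificationPipeline: ...
    @overload
    def pipeline(task: Literal["text-generation"], model=None, **kwargs) -> TextGenerationPipeline: ...
    def pipeline(task=None, model=None, **kwargs) -> Pipeline: ...  # implementation (abbreviated)

The script below regenerates this block from `SUPPORTED_TASKS` and the real `pipeline` signature, and reports an error when the checked-in stubs are out of date.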
import re

from transformers.pipelines import SUPPORTED_TASKS, Pipeline


HEADER = """
# fmt: off
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file below was automatically generated from the code.
# Do NOT edit this part of the file manually as any edits will be overwritten by the generation
# of the file. If any change should be done, please apply the changes to the `pipeline` function
# below and run `python utils/check_pipeline_typing.py --fix_and_overwrite` to update the file.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

from typing import Literal, overload


"""

FOOTER = """
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file above was automatically generated from the code.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# fmt: on
"""

TASK_PATTERN = "task: Optional[str] = None"


def main(pipeline_file_path: str, fix_and_overwrite: bool = False):
    with open(pipeline_file_path, "r") as file:
        content = file.read()

    # extract generated code in between <generated-code> and </generated-code>
    current_generated_code = re.search(r"# <generated-code>(.*)# </generated-code>", content, re.DOTALL).group(1)
    content_without_generated_code = content.replace(current_generated_code, "")

    # extract pipeline signature in between `def pipeline` and `-> Pipeline`
    pipeline_signature = re.search(r"def pipeline(.*) -> Pipeline:", content_without_generated_code, re.DOTALL).group(
        1
    )
    pipeline_signature = pipeline_signature.replace("(\n    ", "(")  # start of the signature
    pipeline_signature = pipeline_signature.replace(",\n    ", ", ")  # intermediate arguments
    pipeline_signature = pipeline_signature.replace(",\n)", ")")  # end of the signature
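    # For illustration (hypothetical input), a black-formatted signature such as
    # "(\n    task: Optional[str] = None,\n    model=None,\n)" is flattened above into
    # "(task: Optional[str] = None, model=None)" so it fits on a single overload line.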

    # collect and sort available pipelines
    pipelines = [(f'"{task}"', task_info["impl"]) for task, task_info in SUPPORTED_TASKS.items()]
    pipelines = sorted(pipelines, key=lambda x: x[0])
    pipelines.insert(0, (None, Pipeline))

    # generate new `pipeline` signatures
    new_generated_code = ""
    for task, pipeline_class in pipelines:
        if TASK_PATTERN not in pipeline_signature:
            raise ValueError(f"Can't find `{TASK_PATTERN}` in pipeline signature: {pipeline_signature}")

        pipeline_type = pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
        new_pipeline_signature = pipeline_signature.replace(TASK_PATTERN, f"task: Literal[{task}]")
        new_generated_code += f"@overload\ndef pipeline{new_pipeline_signature} -> {pipeline_type}: ...\n"

    new_generated_code = HEADER + new_generated_code + FOOTER
    new_generated_code = new_generated_code.rstrip("\n") + "\n"

    if new_generated_code != current_generated_code and fix_and_overwrite:
        print(f"Updating {pipeline_file_path}...")
        wrapped_current_generated_code = "# <generated-code>" + current_generated_code + "# </generated-code>"
        wrapped_new_generated_code = "# <generated-code>" + new_generated_code + "# </generated-code>"
        content = content.replace(wrapped_current_generated_code, wrapped_new_generated_code)

        # write content to file
        with open(pipeline_file_path, "w") as file:
            file.write(content)

    elif new_generated_code != current_generated_code and not fix_and_overwrite:
        message = (
            f"Found inconsistencies in {pipeline_file_path}. "
            "Run `python utils/check_pipeline_typing.py --fix_and_overwrite` to fix them."
        )
        raise ValueError(message)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    parser.add_argument(
        "--pipeline_file_path",
        type=str,
        default="src/transformers/pipelines/__init__.py",
        help="Path to the pipeline file.",
    )
    args = parser.parse_args()
    main(args.pipeline_file_path, args.fix_and_overwrite)
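As a usage sketch (assuming the generated overloads are present in `src/transformers/pipelines/__init__.py`), a static type checker such as mypy or pyright can now narrow the return type of `pipeline` from the task literal, while runtime behaviour is unchanged:

    from transformers import pipeline

    # The task literal selects the matching generated @overload, so the checker
    # infers the concrete pipeline class instead of the generic `Pipeline` base.
    classifier = pipeline("text-classification")  # inferred as TextClassificationPipeline
    generator = pipeline(task="text-generation")  # inferred as TextGenerationPipeline

Running `python utils/check_pipeline_typing.py` only verifies the generated block; passing `--fix_and_overwrite` rewrites it in place.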