transformers/utils/check_pipeline_typing.py
Pavel Iakubovskii b3b7789cbc
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
New model PR merged notification / Notify new model (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
Better pipeline type hints (#38049)
* image-classification

* depth-estimation

* zero-shot-image-classification

* image-feature-extraction

* image-segmentation

* mask-generation

* object-detection

* zero-shot-object-detection

* image-to-image

* image-text-to-text

* image-to-text

* text-classification

* text-generation

* text-to-audio

* text2text-generation

* fixup

* token-classification

* document-qa

* video-classification

* audio-classification

* automatic-speech-recognition

* feature-extraction

* fill-mask

* zero-shot-audio-classification

* Add pipeline function typing

* Add code generator and checker for pipeline types

* Add to makefile

* style

* Add to CI

* Style
2025-06-13 13:44:07 +01:00

94 lines
4.6 KiB
Python

import re
from transformers.pipelines import SUPPORTED_TASKS, Pipeline
# Text placed verbatim before the generated `@overload` stubs. `# fmt: off` stops the code
# formatter from reflowing the generated region, and the banner warns readers not to edit it by
# hand. It also imports `Literal` and `overload`, which every generated stub uses.
HEADER = """
# fmt: off
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file below was automatically generated from the code.
# Do NOT edit this part of the file manually as any edits will be overwritten by the generation
# of the file. If any change should be done, please apply the changes to the `pipeline` function
# below and run `python utils/check_pipeline_typing.py --fix_and_overwrite` to update the file.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Literal, overload
"""
# Text appended after the generated stubs: closes the banner and re-enables formatting (`# fmt: on`).
FOOTER = """
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file above was automatically generated from the code.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# fmt: on
"""
# Exact text of the `task` parameter in the `pipeline` signature. Each generated overload swaps it
# for `task: Literal[...]`, so the real signature must contain this string verbatim.
TASK_PATTERN = "task: Optional[str] = None"
def _flatten_signature(signature: str) -> str:
    """Collapse a one-parameter-per-line signature such as ``(\\n    a,\\n    b,\\n)`` onto one line.

    The patterns ignore how many spaces the parameters are indented by, so the result does not
    depend on the exact indentation of the `pipeline` definition.
    """
    signature = re.sub(r"\(\s*\n\s*", "(", signature)  # start of the signature
    # End of the signature. This runs before the intermediate rule so the trailing ",\n)" becomes
    # ")" rather than ", )".
    signature = re.sub(r",?\s*\n\s*\)", ")", signature)
    signature = re.sub(r",\s*\n\s*", ", ", signature)  # intermediate arguments
    return signature


def _generate_overloads(pipeline_signature: str) -> str:
    """Build the full generated region: HEADER, one `@overload` stub per task, then FOOTER.

    The first stub is the generic fallback (`task: Literal[None]` returning `Pipeline`). It is
    followed by one stub per entry of `SUPPORTED_TASKS`, sorted by task name, each returning that
    task's implementation class.

    Raises:
        ValueError: if `TASK_PATTERN` does not occur in `pipeline_signature`.
    """
    # Checked once up front; the signature is the same for every generated overload.
    if TASK_PATTERN not in pipeline_signature:
        raise ValueError(f"Can't find `{TASK_PATTERN}` in pipeline signature: {pipeline_signature}")

    task_pipelines = sorted(
        ((f'"{task}"', task_info["impl"]) for task, task_info in SUPPORTED_TASKS.items()),
        key=lambda item: item[0],
    )
    pipelines = [(None, Pipeline)] + task_pipelines

    stubs = []
    for task, pipeline_class in pipelines:
        # `impl` may be given as a class or already as its name.
        pipeline_type = pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
        overload_signature = pipeline_signature.replace(TASK_PATTERN, f"task: Literal[{task}]")
        stubs.append(f"@overload\ndef pipeline{overload_signature} -> {pipeline_type}: ...\n")

    generated_code = HEADER + "".join(stubs) + FOOTER
    return generated_code.rstrip("\n") + "\n"


def main(pipeline_file_path: str, fix_and_overwrite: bool = False):
    """Check (or regenerate) the typed `@overload` stubs of `pipeline` in `pipeline_file_path`.

    The stubs live between the `# <generated-code>` and `# </generated-code>` markers. They are
    rebuilt from the real `def pipeline(...) -> Pipeline:` signature and the tasks in
    `SUPPORTED_TASKS`.

    Args:
        pipeline_file_path: Path to the module defining `pipeline`.
        fix_and_overwrite: If True, write the regenerated stubs back into the file when they differ.
            If False, raise instead.

    Raises:
        ValueError: if the markers or the `pipeline` signature cannot be found, if the signature
            lacks `TASK_PATTERN`, or if the stubs are out of date and `fix_and_overwrite` is False.
    """
    # Explicit UTF-8: the generated header contains non-ASCII characters, so a platform-default
    # encoding (e.g. cp1252) can fail to read or write the file.
    with open(pipeline_file_path, "r", encoding="utf-8") as file:
        content = file.read()

    # Extract the generated code between <generated-code> and </generated-code>.
    generated_match = re.search(r"# <generated-code>(.*)# </generated-code>", content, re.DOTALL)
    if generated_match is None:
        raise ValueError(
            f"Can't find the `# <generated-code>` / `# </generated-code>` markers in {pipeline_file_path}."
        )
    current_generated_code = generated_match.group(1)
    region_start, region_end = generated_match.span(1)
    # Drop the existing overloads so the signature search below only sees the real definition.
    content_without_generated_code = content[:region_start] + content[region_end:]

    # Non-greedy: stop at the first ` -> Pipeline:` after `def pipeline`, i.e. the end of its own
    # signature, rather than at the last such annotation anywhere in the file.
    signature_match = re.search(r"def pipeline(.*?) -> Pipeline:", content_without_generated_code, re.DOTALL)
    if signature_match is None:
        raise ValueError(f"Can't find `def pipeline(...) -> Pipeline:` in {pipeline_file_path}.")
    pipeline_signature = _flatten_signature(signature_match.group(1))

    new_generated_code = _generate_overloads(pipeline_signature)
    if new_generated_code == current_generated_code:
        return

    if not fix_and_overwrite:
        message = (
            f"Found inconsistencies in {pipeline_file_path}. "
            "Run `python utils/check_pipeline_typing.py --fix_and_overwrite` to fix them."
        )
        raise ValueError(message)

    print(f"Updating {pipeline_file_path}...")
    # Splice by match position so only the generated region is touched.
    content = content[:region_start] + new_generated_code + content[region_end:]
    with open(pipeline_file_path, "w", encoding="utf-8") as file:
        file.write(content)
if __name__ == "__main__":
    # Command-line entry point: check the pipeline typing stubs, or rewrite them with --fix_and_overwrite.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--fix_and_overwrite",
        action="store_true",
        help="Whether to fix inconsistencies.",
    )
    cli.add_argument(
        "--pipeline_file_path",
        type=str,
        default="src/transformers/pipelines/__init__.py",
        help="Path to the pipeline file.",
    )
    cli_args = cli.parse_args()
    main(
        cli_args.pipeline_file_path,
        cli_args.fix_and_overwrite,
    )