More fixes for doctest (#30265)

* fix

* update

* update

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar, 2024-04-16 11:58:55 +02:00, committed by GitHub
parent 51bcadc10a
commit cbc2cc187a
17 changed files with 65 additions and 32 deletions

@@ -52,6 +52,9 @@ RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https:
 # For `nougat` tokenizer
 RUN python3 -m pip install --no-cache-dir python-Levenshtein
+# For `FastSpeech2ConformerTokenizer` tokenizer
+RUN python3 -m pip install --no-cache-dir g2p-en
 # When installing in editable mode, `transformers` is not recognized as a package.
 # this line must be added in order for python to be aware of transformers.
 RUN cd transformers && python3 setup.py develop

@@ -97,6 +97,7 @@ If you only want the infilled part:
 >>> generator = pipeline("text-generation",model="codellama/CodeLlama-7b-hf",torch_dtype=torch.float16, device_map="auto")
 >>> generator('def remove_non_ascii(s: str) -> str:\n    """ <FILL_ME>\n    return result', max_new_tokens = 128)
+[{'generated_text': 'def remove_non_ascii(s: str) -> str:\n    """ <FILL_ME>\n    return resultRemove non-ASCII characters from a string. """\n    result = ""\n    for c in s:\n        if ord(c) < 128:\n            result += c'}]
 ```
 Under the hood, the tokenizer [automatically splits by `<FILL_ME>`](https://huggingface.co/docs/transformers/main/model_doc/code_llama#transformers.CodeLlamaTokenizer.fill_token) to create a formatted input string that follows [the original training pattern](https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L402). This is more robust than preparing the pattern yourself: it avoids pitfalls, such as token gluing, that are very hard to debug. To see how much CPU and GPU memory you need for this model or others, try [this calculator](https://huggingface.co/spaces/hf-accelerate/model-memory-usage), which can help determine that value.
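The splitting step is easy to visualize. Below is a minimal sketch of the idea, not the library's actual implementation (the authoritative logic lives in `CodeLlamaTokenizer`, and the sentinel token strings here are illustrative): the prompt is cut at the fill token and reassembled into a prefix/suffix/middle infilling layout.

```python
# Hedged sketch of the <FILL_ME> handling described above; sentinel names are
# illustrative, the real formatting is done inside CodeLlamaTokenizer.
prompt = 'def remove_non_ascii(s: str) -> str:\n    """ <FILL_ME>\n    return result'

prefix, suffix = prompt.split("<FILL_ME>")  # fill_token defaults to "<FILL_ME>"
formatted = f"<PRE> {prefix} <SUF>{suffix} <MID>"  # prefix-suffix-middle layout
print(formatted)
```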

@@ -270,11 +270,13 @@ For example, if you use this [invoice image](https://huggingface.co/spaces/impir
 >>> from transformers import pipeline
 >>> vqa = pipeline(model="impira/layoutlm-document-qa")
->>> vqa(
+>>> output = vqa(
 ... image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png",
 ... question="What is the invoice number?",
 ... )
-[{'score': 0.42515, 'answer': 'us-001', 'start': 16, 'end': 16}]
+>>> output[0]["score"] = round(output[0]["score"], 3)
+>>> output
+[{'score': 0.425, 'answer': 'us-001', 'start': 16, 'end': 16}]
 ```
 <Tip>
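A note on the pattern in this hunk (repeated in the localized docs below): pipeline scores are floats whose low-order digits can drift across hardware and library versions, so the doctest now rounds them in place before printing. A standalone illustration of the same trick, with a hypothetical raw value:

```python
# Doctest-stabilization trick used above: round floating-point fields before
# printing, so the expected output no longer depends on low-order float digits.
output = [{"score": 0.4251519, "answer": "us-001", "start": 16, "end": 16}]  # hypothetical
output[0]["score"] = round(output[0]["score"], 3)
print(output)  # [{'score': 0.425, 'answer': 'us-001', 'start': 16, 'end': 16}]
```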

@@ -326,7 +326,7 @@ Document question answering is a task that answers natural language questions fr
 >>> from PIL import Image
 >>> import requests
->>> url = "https://datasets-server.huggingface.co/assets/hf-internal-testing/example-documents/--/hf-internal-testing--example-documents/test/2/image/image.jpg"
+>>> url = "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/2.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> doc_question_answerer = pipeline("document-question-answering", model="magorshunov/layoutlm-invoices")

@@ -325,7 +325,7 @@ Document question answering is a task that answers questions in
 >>> from PIL import Image
 >>> import requests
->>> url = "https://datasets-server.huggingface.co/assets/hf-internal-testing/example-documents/--/hf-internal-testing--example-documents/test/2/image/image.jpg"
+>>> url = "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/2.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> doc_question_answerer = pipeline("document-question-answering", model="magorshunov/layoutlm-invoices")

@@ -270,11 +270,13 @@ Using [`pipeline`] for NLP tasks
 >>> from transformers import pipeline
 >>> vqa = pipeline(model="impira/layoutlm-document-qa")
->>> vqa(
+>>> output = vqa(
 ... image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png",
 ... question="What is the invoice number?",
 ... )
-[{'score': 0.42515, 'answer': 'us-001', 'start': 16, 'end': 16}]
+>>> output[0]["score"] = round(output[0]["score"], 3)
+>>> output
+[{'score': 0.425, 'answer': 'us-001', 'start': 16, 'end': 16}]
 ```
 <Tip>

@@ -95,6 +95,7 @@ def remove_non_ascii(s: str) -> str:
 >>> generator = pipeline("text-generation",model="codellama/CodeLlama-7b-hf",torch_dtype=torch.float16, device_map="auto")
 >>> generator('def remove_non_ascii(s: str) -> str:\n    """ <FILL_ME>\n    return result', max_new_tokens = 128)
+[{'generated_text': 'def remove_non_ascii(s: str) -> str:\n    """ <FILL_ME>\n    return resultRemove non-ASCII characters from a string. """\n    result = ""\n    for c in s:\n        if ord(c) < 128:\n            result += c'}]
 ```
 Under the hood, the tokenizer [automatically splits by `<FILL_ME>`](https://huggingface.co/docs/transformers/main/model_doc/code_llama#transformers.CodeLlamaTokenizer.fill_token) to create a formatted input string that follows [the original training pattern](https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L402). This is more robust than preparing the pattern yourself: it avoids pitfalls, such as token gluing, that are very hard to debug. To see how much CPU and GPU memory this model or others require, try [this calculator](https://huggingface.co/spaces/hf-accelerate/model-memory-usage), which can help determine that value.

@@ -246,11 +246,13 @@ for out in pipe(KeyDataset(dataset, "audio")):
 >>> from transformers import pipeline
 >>> vqa = pipeline(model="impira/layoutlm-document-qa")
->>> vqa(
+>>> output = vqa(
 ... image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png",
 ... question="What is the invoice number?",
 ... )
-[{'score': 0.42515, 'answer': 'us-001', 'start': 16, 'end': 16}]
+>>> output[0]["score"] = round(output[0]["score"], 3)
+>>> output
+[{'score': 0.425, 'answer': 'us-001', 'start': 16, 'end': 16}]
 ```
 <Tip>

@@ -340,7 +340,7 @@ score: 0.9327, start: 30, end: 54, answer: huggingface/transformers
 >>> from PIL import Image
 >>> import requests
->>> url = "https://datasets-server.huggingface.co/assets/hf-internal-testing/example-documents/--/hf-internal-testing--example-documents/test/2/image/image.jpg"
+>>> url = "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/2.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> doc_question_answerer = pipeline("document-question-answering", model="magorshunov/layoutlm-invoices")

@@ -257,11 +257,13 @@ for out in pipe(KeyDataset(dataset, "audio")):
 >>> from transformers import pipeline
 >>> vqa = pipeline(model="impira/layoutlm-document-qa")
->>> vqa(
+>>> output = vqa(
 ... image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png",
 ... question="What is the invoice number?",
 ... )
-[{'score': 0.42515, 'answer': 'us-001', 'start': 16, 'end': 16}]
+>>> output[0]["score"] = round(output[0]["score"], 3)
+>>> output
+[{'score': 0.425, 'answer': 'us-001', 'start': 16, 'end': 16}]
 ```
 <Tip>

@@ -332,7 +332,7 @@ score: 0.9327, start: 30, end: 54, answer: huggingface/transformers
 >>> from PIL import Image
 >>> import requests
->>> url = "https://datasets-server.huggingface.co/assets/hf-internal-testing/example-documents/--/hf-internal-testing--example-documents/test/2/image/image.jpg"
+>>> url = "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/2.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
 >>> doc_question_answerer = pipeline("document-question-answering", model="magorshunov/layoutlm-invoices")

@@ -1719,7 +1719,7 @@ class ClapAudioModel(ClapPreTrainedModel):
 >>> from datasets import load_dataset
 >>> from transformers import AutoProcessor, ClapAudioModel
->>> dataset = load_dataset("ashraq/esc50")
+>>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
 >>> audio_sample = dataset["train"]["audio"][0]["array"]
 >>> model = ClapAudioModel.from_pretrained("laion/clap-htsat-fused")
@@ -2067,7 +2067,7 @@ class ClapModel(ClapPreTrainedModel):
 >>> from datasets import load_dataset
 >>> from transformers import AutoProcessor, ClapModel
->>> dataset = load_dataset("ashraq/esc50")
+>>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
 >>> audio_sample = dataset["train"]["audio"][0]["array"]
 >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
@@ -2260,7 +2260,7 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
 >>> model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused")
 >>> processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
->>> dataset = load_dataset("ashraq/esc50")
+>>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
 >>> audio_sample = dataset["train"]["audio"][0]["array"]
 >>> inputs = processor(audios=audio_sample, return_tensors="pt")
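All four audio hunks in this commit swap the full `ashraq/esc50` dataset for a single-example repository; presumably (the diff itself does not state the motivation) this spares the doctests from downloading the whole dataset just to read one clip. The replacement loads exactly the same way:

```python
# The tiny doctest dataset is a drop-in replacement: only the repo id passed
# to load_dataset(...) changes, the access pattern stays identical.
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
audio_sample = dataset["train"]["audio"][0]["array"]  # 1-D float waveform
```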

@@ -776,7 +776,7 @@ class EncodecModel(EncodecPreTrainedModel):
 >>> from datasets import load_dataset
 >>> from transformers import AutoProcessor, EncodecModel
->>> dataset = load_dataset("ashraq/esc50")
+>>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
 >>> audio_sample = dataset["train"]["audio"][0]["array"]
 >>> model_id = "facebook/encodec_24khz"

@@ -2996,10 +2996,10 @@ class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel):
 ... outputs, threshold=0.35, target_sizes=target_sizes
 ... )[0]
 >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-...     box = [round(i, 2) for i in box.tolist()]
-...     print(f"Detected {label.item()} with confidence " f"{round(score.item(), 3)} at location {box}")
-Detected 1 with confidence 0.453 at location [344.82, 23.18, 637.4, 373.83]
-Detected 1 with confidence 0.408 at location [11.92, 51.58, 316.57, 472.89]
+...     box = [round(i, 1) for i in box.tolist()]
+...     print(f"Detected {label.item()} with confidence " f"{round(score.item(), 2)} at location {box}")
+Detected 1 with confidence 0.45 at location [344.8, 23.2, 637.4, 373.8]
+Detected 1 with confidence 0.41 at location [11.9, 51.6, 316.6, 472.9]
 ```"""
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -33,7 +33,6 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, Ima
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
-    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
@@ -48,10 +47,6 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "SiglipConfig"
 _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
-# Image classification docstring
-_IMAGE_CLASS_CHECKPOINT = "google/siglip-base-patch16-224"
-_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_1"
 from ..deprecated._archive_maps import SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST  # noqa: F401, E402
@@ -1218,12 +1213,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
         self.post_init()
     @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        checkpoint=_IMAGE_CLASS_CHECKPOINT,
-        output_type=ImageClassifierOutput,
-        config_class=_CONFIG_FOR_DOC,
-        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
-    )
+    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
@@ -1237,7 +1227,34 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, SiglipForImageClassification
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> # note: we are loading a `SiglipModel` from the hub here,
+        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
+        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
+        >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> logits = outputs.logits
+        >>> # model predicts one of the two classes
+        >>> predicted_class_idx = logits.argmax(-1).item()
+        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+        Predicted class: LABEL_0
+        ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
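One detail of the new example: `torch.manual_seed(3)` carries `# doctest: +IGNORE_RESULT`. My reading (hedged: `IGNORE_RESULT` appears to be a directive registered by transformers' own doctest setup, not part of the stdlib `doctest` module) is that the directive is needed because the call returns a value, and a bare doctest statement that produces a value must otherwise list that value as expected output:

```python
# Why +IGNORE_RESULT is attached to the seeding line in the example above:
# torch.manual_seed returns a torch.Generator rather than None, so without the
# directive the doctest would demand that object's repr as expected output.
import torch

gen = torch.manual_seed(3)
print(type(gen).__name__)  # Generator
```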

@@ -165,6 +165,7 @@ docs/source/en/model_doc/megatron-bert.md
 docs/source/en/model_doc/megatron_gpt2.md
 docs/source/en/model_doc/mgp-str.md
 docs/source/en/model_doc/mistral.md
+docs/source/en/model_doc/mixtral.md
 docs/source/en/model_doc/mluke.md
 docs/source/en/model_doc/mms.md
 docs/source/en/model_doc/mobilebert.md

@@ -506,6 +506,8 @@ def get_all_doctest_files() -> List[str]:
     test_files_to_run = py_files + md_files
     # change to use "/" as path separator
     test_files_to_run = ["/".join(Path(x).parts) for x in test_files_to_run]
+    # don't run doctest for files in `src/transformers/models/deprecated`
+    test_files_to_run = [x for x in test_files_to_run if "models/deprecated" not in x]
     # only include files in `src` or `docs/source/en/`
     test_files_to_run = [x for x in test_files_to_run if x.startswith(("src/", "docs/source/en/"))]
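The new filter works by substring matching against each path string individually (testing membership against the list as a whole would never match and would keep every file). A quick standalone check with illustrative paths:

```python
# Per-element substring filter, as in get_all_doctest_files above: the
# membership test runs against each path x, not against the list as a whole.
files = [
    "src/transformers/models/deprecated/mctct/modeling_mctct.py",  # illustrative path
    "src/transformers/models/bert/modeling_bert.py",
]
kept = [x for x in files if "models/deprecated" not in x]
assert kept == ["src/transformers/models/bert/modeling_bert.py"]
```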