Tests: detect lines removed from "utils/not_doctested.txt" and doctest ALL generation files (#25763)

Joao Gante 2023-08-29 16:15:05 +01:00 committed by GitHub
parent 483861d52d
commit a35f889acc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 51 deletions

View File

@@ -55,7 +55,7 @@ When you load a model explicitly, you can inspect the generation configuration t
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
->>> model.generation_config
+>>> model.generation_config # doctest: +IGNORE_RESULT
GenerationConfig {
"_from_model_config": true,
"bos_token_id": 50256,
@@ -77,7 +77,7 @@ producing highly repetitive results.
You can override any `generation_config` by passing the parameters and their values directly to the [`generate`] method:
```python
->>> my_model.generate(**inputs, num_beams=4, do_sample=True)
+>>> my_model.generate(**inputs, num_beams=4, do_sample=True) # doctest: +SKIP
```
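The `# doctest: +SKIP` directive used here is plain standard-library doctest, but `+IGNORE_RESULT` from the first hunk is not: it only means something if the test harness registers it as a custom option flag and relaxes the output check when it is present. A minimal sketch of how such a flag can be wired up — the names `IGNORE_RESULT` and `CustomOutputChecker` are illustrative, not necessarily what the repository's test configuration uses:

```python
import doctest

# Register a new option flag so that `# doctest: +IGNORE_RESULT` parses at all.
IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")


class CustomOutputChecker(doctest.OutputChecker):
    def check_output(self, want, got, optionflags):
        # Accept any output for examples marked with +IGNORE_RESULT.
        if IGNORE_RESULT & optionflags:
            return True
        return super().check_output(want, got, optionflags)


# Doctest runners created after this point pick up the relaxed checker.
doctest.OutputChecker = CustomOutputChecker
```

This lets an example execute (catching import errors and exceptions) without pinning down output that is expected to vary, such as a full `GenerationConfig` dump.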
Even if the default decoding strategy mostly works for your task, you can still tweak a few things. Some of the
@@ -107,11 +107,11 @@ If you would like to share your fine-tuned model with a specific generation conf
```python
>>> from transformers import AutoModelForCausalLM, GenerationConfig
->>> model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
+>>> model = AutoModelForCausalLM.from_pretrained("my_account/my_model") # doctest: +SKIP
>>> generation_config = GenerationConfig(
... max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
... )
->>> generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
+>>> generation_config.save_pretrained("my_account/my_model", push_to_hub=True) # doctest: +SKIP
```
You can also store several generation configurations in a single directory, making use of the `config_file_name`
@@ -133,14 +133,15 @@ one for summarization with beam search). You must have the right Hub permissions
... pad_token=model.config.pad_token_id,
... )
->>> translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)
+>>> # Tip: add `push_to_hub=True` to push to the Hub
+>>> translation_generation_config.save_pretrained("/tmp", "translation_generation_config.json")
>>> # You could then use the named generation config file to parameterize generation
->>> generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
+>>> generation_config = GenerationConfig.from_pretrained("/tmp", "translation_generation_config.json")
>>> inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
>>> outputs = model.generate(**inputs, generation_config=generation_config)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
-['Les fichiers de configuration sont faciles à utiliser !']
+['Les fichiers de configuration sont faciles à utiliser!']
```
## Streaming
@@ -217,10 +218,9 @@ The two main parameters that enable and control the behavior of contrastive sear
>>> outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=100)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
-['Hugging Face Company is a family owned and operated business. \
-We pride ourselves on being the best in the business and our customer service is second to none.\
-\n\nIf you have any questions about our products or services, feel free to contact us at any time.\
-We look forward to hearing from you!']
+['Hugging Face Company is a family owned and operated business. We pride ourselves on being the best
+in the business and our customer service is second to none.\n\nIf you have any questions about our
+products or services, feel free to contact us at any time. We look forward to hearing from you!']
```
### Multinomial sampling
@@ -233,7 +233,8 @@ risk of repetition.
To enable multinomial sampling set `do_sample=True` and `num_beams=1`.
```python
->>> from transformers import AutoTokenizer, AutoModelForCausalLM
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+>>> set_seed(0) # For reproducibility
>>> checkpoint = "gpt2-large"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
@@ -244,11 +245,8 @@ To enable multinomial sampling set `do_sample=True` and `num_beams=1`.
>>> outputs = model.generate(**inputs, do_sample=True, num_beams=1, max_new_tokens=100)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
-['Today was an amazing day because we are now in the final stages of our trip to New York City which was very tough. \
-It is a difficult schedule and a challenging part of the year but still worth it. I have been taking things easier and \
-I feel stronger and more motivated to be out there on their tour. Hopefully, that experience is going to help them with \
-their upcoming events which are currently scheduled in Australia.\n\nWe love that they are here. They want to make a \
-name for themselves and become famous for what they']
+['Today was an amazing day because when you go to the World Cup and you don\'t, or when you don\'t get invited,
+that\'s a terrible feeling."']
```
### Beam-search decoding
@@ -272,7 +270,7 @@ To enable this decoding strategy, specify the `num_beams` (aka number of hypothe
>>> outputs = model.generate(**inputs, num_beams=5, max_new_tokens=50)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
-['It is astonishing how one can have such a profound impact on the lives of so many people in such a short period of \
+['It is astonishing how one can have such a profound impact on the lives of so many people in such a short period of
time."\n\nHe added: "I am very proud of the work I have been able to do in the last few years.\n\n"I have']
```
@@ -282,7 +280,8 @@ As the name implies, this decoding strategy combines beam search with multinomia
the `num_beams` greater than 1, and set `do_sample=True` to use this decoding strategy.
```python
->>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, set_seed
+>>> set_seed(0) # For reproducibility
>>> prompt = "translate English to German: The house is wonderful."
>>> checkpoint = "t5-small"
@@ -309,20 +308,22 @@ The diversity penalty ensures the outputs are distinct across groups, and beam s
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> checkpoint = "google/pegasus-xsum"
->>> prompt = "The Permaculture Design Principles are a set of universal design principles \
->>> that can be applied to any location, climate and culture, and they allow us to design \
->>> the most efficient and sustainable human habitation and food production systems. \
->>> Permaculture is a design system that encompasses a wide variety of disciplines, such \
->>> as ecology, landscape design, environmental science and energy conservation, and the \
->>> Permaculture design principles are drawn from these various disciplines. Each individual \
->>> design principle itself embodies a complete conceptual framework based on sound \
->>> scientific principles. When we bring all these separate principles together, we can \
->>> create a design system that both looks at whole systems, the parts that these systems \
->>> consist of, and how those parts interact with each other to create a complex, dynamic, \
->>> living system. Each design principle serves as a tool that allows us to integrate all \
->>> the separate parts of a design, referred to as elements, into a functional, synergistic, \
->>> whole system, where the elements harmoniously interact and work together in the most \
->>> efficient way possible."
+>>> prompt = (
+... "The Permaculture Design Principles are a set of universal design principles "
+... "that can be applied to any location, climate and culture, and they allow us to design "
+... "the most efficient and sustainable human habitation and food production systems. "
+... "Permaculture is a design system that encompasses a wide variety of disciplines, such "
+... "as ecology, landscape design, environmental science and energy conservation, and the "
+... "Permaculture design principles are drawn from these various disciplines. Each individual "
+... "design principle itself embodies a complete conceptual framework based on sound "
+... "scientific principles. When we bring all these separate principles together, we can "
+... "create a design system that both looks at whole systems, the parts that these systems "
+... "consist of, and how those parts interact with each other to create a complex, dynamic, "
+... "living system. Each design principle serves as a tool that allows us to integrate all "
+... "the separate parts of a design, referred to as elements, into a functional, synergistic, "
+... "whole system, where the elements harmoniously interact and work together in the most "
+... "efficient way possible."
+... )
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -369,7 +370,8 @@ When using assisted decoding with sampling methods, you can use the `temperature`
just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improve latency.
```python
->>> from transformers import AutoModelForCausalLM, AutoTokenizer
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
+>>> set_seed(42) # For reproducibility
>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
@@ -382,5 +384,5 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
-["Alice and Bob are sitting on the sofa. Alice says, 'I'm going to my room"]
+['Alice and Bob are going to the same party. It is a small party, in a small']
```
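Several of the edited examples gain a `set_seed(...)` call so that the sampled outputs above can be checked verbatim by doctest. The call seeds every relevant random number generator; a tiny sketch of the effect (assuming a PyTorch install, no model download needed):

```python
import torch

from transformers import set_seed

# With all RNGs seeded identically, two sampling runs draw the same numbers,
# which is what makes the sampled generations above reproducible in doctests.
set_seed(0)
first = torch.rand(3)

set_seed(0)
second = torch.rand(3)

assert torch.equal(first, second)
```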

View File

@@ -14,14 +14,12 @@ docs/source/en/custom_models.md
docs/source/en/custom_tools.md
docs/source/en/debugging.md
docs/source/en/fast_tokenizers.md
-docs/source/en/generation_strategies.md
docs/source/en/glossary.md
docs/source/en/hpo_train.md
docs/source/en/index.md
docs/source/en/installation.md
docs/source/en/internal/audio_utils.md
docs/source/en/internal/file_utils.md
-docs/source/en/internal/generation_utils.md
docs/source/en/internal/image_processing_utils.md
docs/source/en/internal/modeling_utils.md
docs/source/en/internal/pipelines_utils.md
@@ -45,7 +43,6 @@ docs/source/en/main_classes/output.md
docs/source/en/main_classes/pipelines.md
docs/source/en/main_classes/processors.md
docs/source/en/main_classes/quantization.md
-docs/source/en/main_classes/text_generation.md
docs/source/en/main_classes/tokenizer.md
docs/source/en/main_classes/trainer.md
docs/source/en/model_doc/albert.md
@@ -367,16 +364,6 @@ src/transformers/dynamic_module_utils.py
src/transformers/feature_extraction_sequence_utils.py
src/transformers/feature_extraction_utils.py
src/transformers/file_utils.py
-src/transformers/generation/beam_constraints.py
-src/transformers/generation/beam_search.py
-src/transformers/generation/flax_logits_process.py
-src/transformers/generation/flax_utils.py
-src/transformers/generation/stopping_criteria.py
-src/transformers/generation/streamers.py
-src/transformers/generation/tf_logits_process.py
-src/transformers/generation_flax_utils.py
-src/transformers/generation_tf_utils.py
-src/transformers/generation_utils.py
src/transformers/hf_argparser.py
src/transformers/hyperparameter_search.py
src/transformers/image_processing_utils.py
@@ -413,7 +400,6 @@ src/transformers/models/auto/modeling_tf_auto.py
src/transformers/models/autoformer/configuration_autoformer.py
src/transformers/models/autoformer/modeling_autoformer.py
src/transformers/models/bark/convert_suno_to_hf.py
-src/transformers/models/bark/generation_configuration_bark.py
src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/bart/modeling_flax_bart.py
src/transformers/models/bart/modeling_tf_bart.py
@@ -925,9 +911,7 @@ src/transformers/pipelines/object_detection.py
src/transformers/pipelines/pt_utils.py
src/transformers/pipelines/question_answering.py
src/transformers/pipelines/table_question_answering.py
-src/transformers/pipelines/text2text_generation.py
src/transformers/pipelines/text_classification.py
-src/transformers/pipelines/text_generation.py
src/transformers/pipelines/token_classification.py
src/transformers/pipelines/video_classification.py
src/transformers/pipelines/visual_question_answering.py

View File

@@ -1 +1,2 @@
+docs/source/en/generation_strategies.md
docs/source/en/task_summary.md

View File

@@ -395,6 +395,31 @@ def get_all_doctest_files() -> List[str]:
    return sorted(test_files_to_run)

+def get_new_doctest_files(repo, base_commit, branching_commit) -> List[str]:
+    """
+    Get the list of files that were removed from "utils/not_doctested.txt", between `base_commit` and
+    `branching_commit`.
+
+    Returns:
+        `List[str]`: List of files that were removed from "utils/not_doctested.txt".
+    """
+    for diff_obj in branching_commit.diff(base_commit):
+        # Ignores all but the "utils/not_doctested.txt" file.
+        if diff_obj.a_path != "utils/not_doctested.txt":
+            continue
+        # Loads the two versions
+        folder = Path(repo.working_dir)
+        with checkout_commit(repo, branching_commit):
+            with open(folder / "utils/not_doctested.txt", "r", encoding="utf-8") as f:
+                old_content = f.read()
+        with open(folder / "utils/not_doctested.txt", "r", encoding="utf-8") as f:
+            new_content = f.read()
+        # Compute the removed lines and return them
+        removed_content = set(old_content.split("\n")) - set(new_content.split("\n"))
+        return sorted(removed_content)
+
+    return []

def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
    """
    Return a list of python and Markdown files where doc example have been modified between:
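The heart of `get_new_doctest_files` is the set difference between the two versions of `utils/not_doctested.txt`: anything present before the branching commit but missing now is a freshly un-excluded file. A self-contained illustration of that computation with made-up file contents (no git involved):

```python
# Hypothetical old and new contents of utils/not_doctested.txt.
old_content = "\n".join(
    [
        "docs/source/en/fast_tokenizers.md",
        "docs/source/en/generation_strategies.md",  # removed by this commit
        "docs/source/en/glossary.md",
    ]
)
new_content = "\n".join(
    [
        "docs/source/en/fast_tokenizers.md",
        "docs/source/en/glossary.md",
    ]
)

# Same computation as in get_new_doctest_files: lines that disappeared from the
# exclusion list are the files that must now pass their doctests.
removed_content = set(old_content.split("\n")) - set(new_content.split("\n"))
print(sorted(removed_content))  # ['docs/source/en/generation_strategies.md']
```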
@@ -426,6 +451,10 @@ def get_doctest_files(diff_with_last_commit: bool = False) -> List[str]:
    all_test_files_to_run = get_all_doctest_files()

+    # Add to the test files to run any removed entry from "utils/not_doctested.txt".
+    new_test_files = get_new_doctest_files(repo, repo.head.commit, repo.refs.main.commit)
+    test_files_to_run = list(set(test_files_to_run + new_test_files))
+
    # Do not run slow doctest tests on CircleCI
    with open("utils/slow_documentation_tests.txt") as fp:
        slow_documentation_tests = set(fp.read().strip().split("\n"))
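Downstream, the removed entries are merged into the doctest job and then, per the comment above, files listed in `utils/slow_documentation_tests.txt` are kept out of the CircleCI run — which is why `docs/source/en/generation_strategies.md` is both removed from `not_doctested.txt` and added to the slow list in this commit. A rough sketch of that interaction; the final filtering line is an assumption about code not shown in this hunk:

```python
# Stand-ins for the values computed in get_doctest_files (hypothetical inputs).
test_files_to_run = ["docs/source/en/task_summary.md"]
new_test_files = ["docs/source/en/generation_strategies.md"]  # freed from not_doctested.txt
slow_documentation_tests = {"docs/source/en/generation_strategies.md"}  # slow_documentation_tests.txt

# Newly doctested files join the run...
test_files_to_run = list(set(test_files_to_run + new_test_files))

# ...but slow documentation tests are skipped on CircleCI (assumed filtering step).
test_files_to_run = [f for f in test_files_to_run if f not in slow_documentation_tests]
print(sorted(test_files_to_run))  # ['docs/source/en/task_summary.md']
```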