Fix ROUGE add example check and update README (#18398)

* Fix ROUGE add example check and update README

* Stay consistent in values
This commit is contained in:
Sylvain Gugger 2022-08-01 11:14:49 -04:00 committed by GitHub
parent 62098b9348
commit 941d233153
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 60 additions and 28 deletions

View File

@ -43,6 +43,22 @@ To browse the examples corresponding to released versions of 🤗 Transformers,
<details>
<summary>Examples for older versions of 🤗 Transformers</summary>
<ul>
<li><a href="https://github.com/huggingface/transformers/tree/v4.21.0/examples">v4.21.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.20.1/examples">v4.20.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.19.4/examples">v4.19.4</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.18.0/examples">v4.18.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.17.0/examples">v4.17.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.16.2/examples">v4.16.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.15.0/examples">v4.15.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.14.1/examples">v4.14.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.13.0/examples">v4.13.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.12.5/examples">v4.12.5</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.11.3/examples">v4.11.3</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.10.3/examples">v4.10.3</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.9.2/examples">v4.9.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.8.2/examples">v4.8.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.7.0/examples">v4.7.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.6.1/examples">v4.6.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.5.1/examples">v4.5.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.4.2/examples">v4.4.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.3.3/examples">v4.3.3</a></li>

View File

@ -3,3 +3,4 @@ jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8
evaluate>=0.2.0

View File

@ -680,12 +680,9 @@ def main():
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
result = {k: round(v * 100, 4) for k, v in result.items()}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
result = {k: round(v, 4) for k, v in result.items()}
return result
# Enable tensorboard only on the master node

View File

@ -23,4 +23,4 @@ torchvision
jiwer
librosa
torch < 1.12
evaluate
evaluate >= 0.2.0

View File

@ -48,10 +48,13 @@ from transformers import (
SchedulerType,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -52,10 +52,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -52,10 +52,13 @@ from transformers import (
SchedulerType,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())

View File

@ -52,9 +52,12 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import PaddingStrategy, get_full_repo_name, send_example_telemetry
from transformers.utils import PaddingStrategy, check_min_version, get_full_repo_name, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())

View File

@ -45,10 +45,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@ -625,12 +625,9 @@ def main():
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
result = {k: round(v * 100, 4) for k, v in result.items()}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
result = {k: round(v, 4) for k, v in result.items()}
return result
# Initialize our Trainer

View File

@ -51,10 +51,13 @@ from transformers import (
SchedulerType,
get_scheduler,
)
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@ -707,10 +710,7 @@ def main():
references=decoded_labels,
)
result = metric.compute(use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
result = {k: round(v, 4) for k, v in result.items()}
result = {k: round(v * 100, 4) for k, v in result.items()}
logger.info(result)

View File

@ -43,10 +43,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -50,10 +50,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -52,10 +52,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -680,10 +680,7 @@ def main():
metric.add_batch(predictions=decoded_preds, references=decoded_labels)
result = metric.compute(use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
result = {k: round(v, 4) for k, v in result.items()}
result = {k: round(v * 100, 4) for k, v in result.items()}
logger.info(result)
# endregion

View File

@ -105,7 +105,7 @@ _deps = [
"datasets",
"deepspeed>=0.6.5",
"dill<0.3.5",
"evaluate",
"evaluate>=0.2.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",

View File

@ -11,7 +11,7 @@ deps = {
"datasets": "datasets",
"deepspeed": "deepspeed>=0.6.5",
"dill": "dill<0.3.5",
"evaluate": "evaluate",
"evaluate": "evaluate>=0.2.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",