mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Fix ROUGE add example check and update README (#18398)

* Fix ROUGE add example check and update README
* Stay consistent in values

This commit is contained in:
parent 62098b9348
commit 941d233153
@@ -43,6 +43,22 @@ To browse the examples corresponding to released versions of 🤗 Transformers,
<details>
<summary>Examples for older versions of 🤗 Transformers</summary>
<ul>
<li><a href="https://github.com/huggingface/transformers/tree/v4.21.0/examples">v4.21.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.20.1/examples">v4.20.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.19.4/examples">v4.19.4</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.18.0/examples">v4.18.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.17.0/examples">v4.17.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.16.2/examples">v4.16.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.15.0/examples">v4.15.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.14.1/examples">v4.14.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.13.0/examples">v4.13.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.12.5/examples">v4.12.5</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.11.3/examples">v4.11.3</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.10.3/examples">v4.10.3</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.9.2/examples">v4.9.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.8.2/examples">v4.8.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.7.0/examples">v4.7.0</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.6.1/examples">v4.6.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.5.1/examples">v4.5.1</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.4.2/examples">v4.4.2</a></li>
<li><a href="https://github.com/huggingface/transformers/tree/v4.3.3/examples">v4.3.3</a></li>
@@ -3,3 +3,4 @@ jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8
evaluate>=0.2.0
@@ -680,12 +680,9 @@ def main():
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

result = {k: round(v * 100, 4) for k, v in result.items()}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
result = {k: round(v, 4) for k, v in result.items()}
return result

# Enable tensorboard only on the master node
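The hunk above, and the matching ones further down, all make the same fix: with evaluate >= 0.2.0 the ROUGE metric returns plain floats instead of AggregateScore objects, so the value.mid.fmeasure extraction is dropped and the scripts simply scale to percentages and round in one step. A minimal sketch of the new post-processing, assuming evaluate >= 0.2.0 is installed and using made-up toy strings:

```python
import evaluate

# Hypothetical toy inputs, only to show the shape of the output.
predictions = ["the cat sat on the mat"]
references = ["the cat sat on the mat"]

rouge = evaluate.load("rouge")
# With evaluate >= 0.2.0 this returns a dict of plain floats, e.g. {"rouge1": 1.0, ...}.
result = rouge.compute(predictions=predictions, references=references, use_stemmer=True)

# The examples now scale and round in one step, instead of reading
# value.mid.fmeasure from an AggregateScore as before.
result = {k: round(v * 100, 4) for k, v in result.items()}
print(result)
```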
@@ -23,4 +23,4 @@ torchvision
jiwer
librosa
torch < 1.12
evaluate
evaluate >= 0.2.0
@@ -48,10 +48,13 @@ from transformers import (
SchedulerType,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
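Each no_trainer example script now fails fast when the installed library is older than the development version the examples track. Roughly, the added guard amounts to the following sketch (an assumption for illustration only; the real helper is transformers.utils.check_min_version):

```python
# Rough equivalent of check_min_version("4.22.0.dev0"); sketch only, not the actual implementation.
from packaging import version

import transformers

MIN_VERSION = "4.22.0.dev0"  # version floor added by this commit
if version.parse(transformers.__version__) < version.parse(MIN_VERSION):
    raise ImportError(
        f"This example requires transformers >= {MIN_VERSION}, "
        f"found {transformers.__version__}."
    )
```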
@@ -52,10 +52,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -52,10 +52,13 @@ from transformers import (
SchedulerType,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -52,9 +52,12 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import PaddingStrategy, get_full_repo_name, send_example_telemetry
from transformers.utils import PaddingStrategy, check_min_version, get_full_repo_name, send_example_telemetry


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -45,10 +45,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)

require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
@@ -625,12 +625,9 @@ def main():
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

result = {k: round(v * 100, 4) for k, v in result.items()}
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)
result = {k: round(v, 4) for k, v in result.items()}
return result

# Initialize our Trainer
@@ -51,10 +51,13 @@ from transformers import (
SchedulerType,
get_scheduler,
)
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
@@ -707,10 +710,7 @@ def main():
references=decoded_labels,
)
result = metric.compute(use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

result = {k: round(v, 4) for k, v in result.items()}
result = {k: round(v * 100, 4) for k, v in result.items()}

logger.info(result)
@@ -43,10 +43,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -50,10 +50,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
@@ -52,10 +52,13 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.22.0.dev0")

logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
@@ -680,10 +680,7 @@ def main():
metric.add_batch(predictions=decoded_preds, references=decoded_labels)

result = metric.compute(use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

result = {k: round(v, 4) for k, v in result.items()}
result = {k: round(v * 100, 4) for k, v in result.items()}

logger.info(result)
# endregion
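In the evaluation loops above, predictions are fed to the metric batch by batch and ROUGE is computed once at the end, which is why compute() takes no predictions/references there. A sketch of that accumulation pattern, again assuming evaluate >= 0.2.0 and hypothetical toy strings:

```python
import evaluate

rouge = evaluate.load("rouge")

# Hypothetical eval batches, only to illustrate the add_batch/compute flow.
batches = [
    (["a small cat"], ["a small cat"]),
    (["dogs run fast"], ["dogs run very fast"]),
]
for decoded_preds, decoded_labels in batches:
    rouge.add_batch(predictions=decoded_preds, references=decoded_labels)

result = rouge.compute(use_stemmer=True)  # dict of floats with evaluate >= 0.2.0
result = {k: round(v * 100, 4) for k, v in result.items()}
print(result)
```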
setup.py
@@ -105,7 +105,7 @@ _deps = [
"datasets",
"deepspeed>=0.6.5",
"dill<0.3.5",
"evaluate",
"evaluate>=0.2.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
@@ -11,7 +11,7 @@ deps = {
"datasets": "datasets",
"deepspeed": "deepspeed>=0.6.5",
"dill": "dill<0.3.5",
"evaluate": "evaluate",
"evaluate": "evaluate>=0.2.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
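The evaluate floor set in setup.py and the dependency table uses the same requirement-string format that the example scripts already enforce at runtime via require_version; a small sketch (the hint message here is made up, not from the commit):

```python
from transformers.utils.versions import require_version

# Raises with the given hint if the installed evaluate package is older than 0.2.0.
require_version("evaluate>=0.2.0", "To fix: pip install -U evaluate")
```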