Enable tqdm toggling (#15167)

* feature: enable tqdm toggle

* test: add tqdm unit test

* style: run linter

* Update tests/test_tqdm_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* refactor: use tiny model, run linter

* docs: add tqdm to logging

* docs: add tqdm reference to `http_get`

* style: run linter

* Update docs/source/main_classes/logging.mdx

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* fix: use `AutoConfig` for framework agnostic testing

* chore: mv tqdm test to `test_logging.py`

* feature: implement enable/disable functions

* docs: mv docstring to comment

* chore: mv tqdm functions to `logging.py`

* docs: update docs to reference `enable/disable` funcs

* test: update test to use `enable/disable` func

* chore: update function reference in comment

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Jake Tae 2022-01-19 07:52:35 +09:00 committed by GitHub
parent 2c335037bd
commit fe78fe98ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 3 deletions

View File

@ -54,6 +54,8 @@ verbose to the most verbose), those levels (with their corresponding int values
- `transformers.logging.INFO` (int value, 20): reports error, warnings and basic information. - `transformers.logging.INFO` (int value, 20): reports error, warnings and basic information.
- `transformers.logging.DEBUG` (int value, 10): report all information. - `transformers.logging.DEBUG` (int value, 10): report all information.
By default, `tqdm` progress bars will be displayed during model download. [`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] can be used to suppress or unsuppress this behavior.
## Base setters ## Base setters
[[autodoc]] logging.set_verbosity_error [[autodoc]] logging.set_verbosity_error
@ -79,3 +81,7 @@ verbose to the most verbose), those levels (with their corresponding int values
[[autodoc]] logging.enable_explicit_format [[autodoc]] logging.enable_explicit_format
[[autodoc]] logging.reset_format [[autodoc]] logging.reset_format
[[autodoc]] logging.enable_progress_bar
[[autodoc]] logging.disable_progress_bar

View File

@ -45,12 +45,12 @@ from zipfile import ZipFile, is_zipfile
import numpy as np import numpy as np
from packaging import version from packaging import version
from tqdm.auto import tqdm
import requests import requests
from filelock import FileLock from filelock import FileLock
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm
from transformers.utils.versions import importlib_metadata from transformers.utils.versions import importlib_metadata
from . import __version__ from . import __version__
@ -1911,6 +1911,8 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
r.raise_for_status() r.raise_for_status()
content_length = r.headers.get("Content-Length") content_length = r.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None total = resume_size + int(content_length) if content_length is not None else None
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
# and can be set using `utils.logging.enable/disable_progress_bar()`
progress = tqdm( progress = tqdm(
unit="B", unit="B",
unit_scale=True, unit_scale=True,
@ -1918,7 +1920,6 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
total=total, total=total,
initial=resume_size, initial=resume_size,
desc="Downloading", desc="Downloading",
disable=bool(logging.get_verbosity() == logging.NOTSET),
) )
for chunk in r.iter_content(chunk_size=1024): for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks

View File

@ -28,6 +28,8 @@ from logging import WARN # NOQA
from logging import WARNING # NOQA from logging import WARNING # NOQA
from typing import Optional from typing import Optional
from tqdm import auto as tqdm_lib
_lock = threading.Lock() _lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None _default_handler: Optional[logging.Handler] = None
@ -42,6 +44,8 @@ log_levels = {
_default_log_level = logging.WARNING _default_log_level = logging.WARNING
_tqdm_active = True
def _get_default_logging_level(): def _get_default_logging_level():
""" """
@ -276,3 +280,65 @@ def warning_advice(self, *args, **kwargs):
logging.Logger.warning_advice = warning_advice logging.Logger.warning_advice = warning_advice
class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
self._iterator = args[0] if args else None
def __iter__(self):
return iter(self._iterator)
def __getattr__(self, _):
"""Return empty function."""
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
return
return empty_fn
def __enter__(self):
return self
def __exit__(self, type_, value, traceback):
return
class _tqdm_cls:
def __call__(self, *args, **kwargs):
if _tqdm_active:
return tqdm_lib.tqdm(*args, **kwargs)
else:
return EmptyTqdm(*args, **kwargs)
def set_lock(self, *args, **kwargs):
self._lock = None
if _tqdm_active:
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
def get_lock(self):
if _tqdm_active:
return tqdm_lib.tqdm.get_lock()
tqdm = _tqdm_cls()
def is_progress_bar_enabled() -> bool:
"""Return a boolean indicating whether tqdm progress bars are enabled."""
global _tqdm_active
return bool(_tqdm_active)
def enable_progress_bar():
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = True
def disable_progress_bar():
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = False

View File

@ -14,10 +14,12 @@
import os import os
import unittest import unittest
from unittest.mock import patch
import transformers.models.bart.tokenization_bart import transformers.models.bart.tokenization_bart
from transformers import logging from transformers import AutoConfig, logging
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
class HfArgumentParserTest(unittest.TestCase): class HfArgumentParserTest(unittest.TestCase):
@ -121,3 +123,17 @@ class HfArgumentParserTest(unittest.TestCase):
with CaptureLogger(logger) as cl: with CaptureLogger(logger) as cl:
logger.warning_advice(msg) logger.warning_advice(msg)
self.assertEqual(cl.out, msg + "\n") self.assertEqual(cl.out, msg + "\n")
def test_set_progress_bar_enabled():
TINY_MODEL = "hf-internal-testing/tiny-random-distilbert"
with patch("tqdm.auto.tqdm") as mock_tqdm:
disable_progress_bar()
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
mock_tqdm.assert_not_called()
mock_tqdm.reset_mock()
enable_progress_bar()
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
mock_tqdm.assert_called()