Enable tqdm toggling (#15167)

* feature: enable tqdm toggle

* test: add tqdm unit test

* style: run linter

* Update tests/test_tqdm_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* refactor: use tiny model, run linter

* docs: add tqdm to logging

* docs: add tqdm reference to `http_get`

* style: run linter

* Update docs/source/main_classes/logging.mdx

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* fix: use `AutoConfig` for framework agnostic testing

* chore: mv tqdm test to `test_logging.py`

* feature: implement enable/disable functions

* docs: mv docstring to comment

* chore: mv tqdm functions to `logging.py`

* docs: update docs to reference `enable/disable` funcs

* test: update test to use `enable/disable` func

* chore: update function reference in comment

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Jake Tae 2022-01-19 07:52:35 +09:00 committed by GitHub
parent 2c335037bd
commit fe78fe98ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 92 additions and 3 deletions

View File

@ -54,6 +54,8 @@ verbose to the most verbose), those levels (with their corresponding int values
- `transformers.logging.INFO` (int value, 20): reports error, warnings and basic information.
- `transformers.logging.DEBUG` (int value, 10): report all information.
By default, `tqdm` progress bars will be displayed during model download. [`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] can be used to suppress or unsuppress this behavior.
## Base setters
[[autodoc]] logging.set_verbosity_error
@ -79,3 +81,7 @@ verbose to the most verbose), those levels (with their corresponding int values
[[autodoc]] logging.enable_explicit_format
[[autodoc]] logging.reset_format
[[autodoc]] logging.enable_progress_bar
[[autodoc]] logging.disable_progress_bar

View File

@ -45,12 +45,12 @@ from zipfile import ZipFile, is_zipfile
import numpy as np
from packaging import version
from tqdm.auto import tqdm
import requests
from filelock import FileLock
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm
from transformers.utils.versions import importlib_metadata
from . import __version__
@ -1911,6 +1911,8 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
r.raise_for_status()
content_length = r.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
# and can be set using `utils.logging.enable/disable_progress_bar()`
progress = tqdm(
unit="B",
unit_scale=True,
@ -1918,7 +1920,6 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
total=total,
initial=resume_size,
desc="Downloading",
disable=bool(logging.get_verbosity() == logging.NOTSET),
)
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks

View File

@ -28,6 +28,8 @@ from logging import WARN # NOQA
from logging import WARNING # NOQA
from typing import Optional
from tqdm import auto as tqdm_lib
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
@ -42,6 +44,8 @@ log_levels = {
_default_log_level = logging.WARNING
_tqdm_active = True
def _get_default_logging_level():
"""
@ -276,3 +280,65 @@ def warning_advice(self, *args, **kwargs):
logging.Logger.warning_advice = warning_advice
class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
self._iterator = args[0] if args else None
def __iter__(self):
return iter(self._iterator)
def __getattr__(self, _):
"""Return empty function."""
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
return
return empty_fn
def __enter__(self):
return self
def __exit__(self, type_, value, traceback):
return
class _tqdm_cls:
def __call__(self, *args, **kwargs):
if _tqdm_active:
return tqdm_lib.tqdm(*args, **kwargs)
else:
return EmptyTqdm(*args, **kwargs)
def set_lock(self, *args, **kwargs):
self._lock = None
if _tqdm_active:
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
def get_lock(self):
if _tqdm_active:
return tqdm_lib.tqdm.get_lock()
tqdm = _tqdm_cls()
def is_progress_bar_enabled() -> bool:
"""Return a boolean indicating whether tqdm progress bars are enabled."""
global _tqdm_active
return bool(_tqdm_active)
def enable_progress_bar():
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = True
def disable_progress_bar():
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = False

View File

@ -14,10 +14,12 @@
import os
import unittest
from unittest.mock import patch
import transformers.models.bart.tokenization_bart
from transformers import logging
from transformers import AutoConfig, logging
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
class HfArgumentParserTest(unittest.TestCase):
@ -121,3 +123,17 @@ class HfArgumentParserTest(unittest.TestCase):
with CaptureLogger(logger) as cl:
logger.warning_advice(msg)
self.assertEqual(cl.out, msg + "\n")
def test_set_progress_bar_enabled():
TINY_MODEL = "hf-internal-testing/tiny-random-distilbert"
with patch("tqdm.auto.tqdm") as mock_tqdm:
disable_progress_bar()
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
mock_tqdm.assert_not_called()
mock_tqdm.reset_mock()
enable_progress_bar()
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
mock_tqdm.assert_called()