mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Enable tqdm toggling (#15167)
* feature: enable tqdm toggle * test: add tqdm unit test * style: run linter * Update tests/test_tqdm_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * refactor: use tiny model, run linter * docs: add tqdm to logging * docs: add tqdm reference to `http_get` * style: run linter * Update docs/source/main_classes/logging.mdx Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * fix: use `AutoConfig` for framework agnostic testing * chore: mv tqdm test to `test_logging.py` * feature: implement enable/disable functions * docs: mv docstring to comment * chore: mv tqdm functions to `logging.py` * docs: update docs to reference `enable/disable` funcs * test: update test to use `enable/disable` func * chore: update function reference in comment Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
parent
2c335037bd
commit
fe78fe98ca
@ -54,6 +54,8 @@ verbose to the most verbose), those levels (with their corresponding int values
|
||||
- `transformers.logging.INFO` (int value, 20): reports error, warnings and basic information.
|
||||
- `transformers.logging.DEBUG` (int value, 10): report all information.
|
||||
|
||||
By default, `tqdm` progress bars will be displayed during model download. [`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] can be used to suppress or unsuppress this behavior.
|
||||
|
||||
## Base setters
|
||||
|
||||
[[autodoc]] logging.set_verbosity_error
|
||||
@ -79,3 +81,7 @@ verbose to the most verbose), those levels (with their corresponding int values
|
||||
[[autodoc]] logging.enable_explicit_format
|
||||
|
||||
[[autodoc]] logging.reset_format
|
||||
|
||||
[[autodoc]] logging.enable_progress_bar
|
||||
|
||||
[[autodoc]] logging.disable_progress_bar
|
||||
|
@ -45,12 +45,12 @@ from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers.utils.logging import tqdm
|
||||
from transformers.utils.versions import importlib_metadata
|
||||
|
||||
from . import __version__
|
||||
@ -1911,6 +1911,8 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
|
||||
r.raise_for_status()
|
||||
content_length = r.headers.get("Content-Length")
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
|
||||
# and can be set using `utils.logging.enable/disable_progress_bar()`
|
||||
progress = tqdm(
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
@ -1918,7 +1920,6 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
|
||||
total=total,
|
||||
initial=resume_size,
|
||||
desc="Downloading",
|
||||
disable=bool(logging.get_verbosity() == logging.NOTSET),
|
||||
)
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
|
@ -28,6 +28,8 @@ from logging import WARN # NOQA
|
||||
from logging import WARNING # NOQA
|
||||
from typing import Optional
|
||||
|
||||
from tqdm import auto as tqdm_lib
|
||||
|
||||
|
||||
_lock = threading.Lock()
|
||||
_default_handler: Optional[logging.Handler] = None
|
||||
@ -42,6 +44,8 @@ log_levels = {
|
||||
|
||||
_default_log_level = logging.WARNING
|
||||
|
||||
_tqdm_active = True
|
||||
|
||||
|
||||
def _get_default_logging_level():
|
||||
"""
|
||||
@ -276,3 +280,65 @@ def warning_advice(self, *args, **kwargs):
|
||||
|
||||
|
||||
logging.Logger.warning_advice = warning_advice
|
||||
|
||||
|
||||
class EmptyTqdm:
|
||||
"""Dummy tqdm which doesn't do anything."""
|
||||
|
||||
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||
self._iterator = args[0] if args else None
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._iterator)
|
||||
|
||||
def __getattr__(self, _):
|
||||
"""Return empty function."""
|
||||
|
||||
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return
|
||||
|
||||
return empty_fn
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type_, value, traceback):
|
||||
return
|
||||
|
||||
|
||||
class _tqdm_cls:
|
||||
def __call__(self, *args, **kwargs):
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm(*args, **kwargs)
|
||||
else:
|
||||
return EmptyTqdm(*args, **kwargs)
|
||||
|
||||
def set_lock(self, *args, **kwargs):
|
||||
self._lock = None
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
||||
|
||||
def get_lock(self):
|
||||
if _tqdm_active:
|
||||
return tqdm_lib.tqdm.get_lock()
|
||||
|
||||
|
||||
tqdm = _tqdm_cls()
|
||||
|
||||
|
||||
def is_progress_bar_enabled() -> bool:
|
||||
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
||||
global _tqdm_active
|
||||
return bool(_tqdm_active)
|
||||
|
||||
|
||||
def enable_progress_bar():
|
||||
"""Enable tqdm progress bar."""
|
||||
global _tqdm_active
|
||||
_tqdm_active = True
|
||||
|
||||
|
||||
def disable_progress_bar():
|
||||
"""Enable tqdm progress bar."""
|
||||
global _tqdm_active
|
||||
_tqdm_active = False
|
||||
|
@ -14,10 +14,12 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import transformers.models.bart.tokenization_bart
|
||||
from transformers import logging
|
||||
from transformers import AutoConfig, logging
|
||||
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
|
||||
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
|
||||
|
||||
|
||||
class HfArgumentParserTest(unittest.TestCase):
|
||||
@ -121,3 +123,17 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
with CaptureLogger(logger) as cl:
|
||||
logger.warning_advice(msg)
|
||||
self.assertEqual(cl.out, msg + "\n")
|
||||
|
||||
|
||||
def test_set_progress_bar_enabled():
|
||||
TINY_MODEL = "hf-internal-testing/tiny-random-distilbert"
|
||||
with patch("tqdm.auto.tqdm") as mock_tqdm:
|
||||
disable_progress_bar()
|
||||
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
|
||||
mock_tqdm.assert_not_called()
|
||||
|
||||
mock_tqdm.reset_mock()
|
||||
|
||||
enable_progress_bar()
|
||||
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
|
||||
mock_tqdm.assert_called()
|
||||
|
Loading…
Reference in New Issue
Block a user