Fix RequestCounter to make it more future-proof (#27406)

* Fix RequestCounter to make it more future-proof

* code quality
Lucain 2023-11-09 18:53:26 +01:00 committed by GitHub
parent c8b6052ff6
commit e38348ae8f
5 changed files with 48 additions and 45 deletions
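The heart of the change: the old RequestCounter counted requests by monkey-patching huggingface_hub's `get_session`, which is exactly the internal that new huggingface_hub releases kept breaking (see the skipped tests re-enabled below). The new version instead spies on urllib3's connection-pool debug logger, which every request passes through no matter how the session is built. A minimal stand-alone sketch of that technique, assuming only that urllib3 emits one debug record per request (the URL is just an example):

```py
# Wrap urllib3's connection-pool logger with a spy: every HTTP request
# emits one debug record there, so counting records counts requests
# without depending on huggingface_hub internals.
from unittest.mock import patch

import requests
import urllib3

with patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug) as spy:
    requests.head("https://huggingface.co")

for call in spy.call_args_list:
    # Each record is a printf-style format string plus its arguments,
    # e.g. 'https://huggingface.co:443 "HEAD / HTTP/1.1" 200 0'.
    print(call.args[0] % call.args[1:])
```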


@@ -29,14 +29,15 @@ import sys
 import tempfile
 import time
 import unittest
+from collections import defaultdict
 from collections.abc import Mapping
 from io import StringIO
 from pathlib import Path
 from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
 from unittest import mock
+from unittest.mock import patch

-import huggingface_hub
-import requests
+import urllib3
 from transformers import logging as transformers_logging
@@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False):
 class RequestCounter:
     """
     Helper class that will count all requests made online.
+
+    Might not be robust if urllib3 changes its logging format but should be good enough for us.
+
+    Usage:
+    ```py
+    with RequestCounter() as counter:
+        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+    assert counter["GET"] == 0
+    assert counter["HEAD"] == 1
+    assert counter.total_calls == 1
+    ```
     """

     def __enter__(self):
-        self.head_request_count = 0
-        self.get_request_count = 0
-        self.other_request_count = 0
-
-        # Mock `get_session` to count HTTP calls.
-        self.old_get_session = huggingface_hub.utils._http.get_session
-        self.session = requests.Session()
-        self.session.request = self.new_request
-        huggingface_hub.utils._http.get_session = lambda: self.session
+        self._counter = defaultdict(int)
+        self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
+        self.mock = self.patcher.start()
         return self

-    def __exit__(self, *args, **kwargs):
-        huggingface_hub.utils._http.get_session = self.old_get_session
+    def __exit__(self, *args, **kwargs) -> None:
+        for call in self.mock.call_args_list:
+            log = call.args[0] % call.args[1:]
+            for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
+                if method in log:
+                    self._counter[method] += 1
+                    break
+        self.patcher.stop()

-    def new_request(self, method, **kwargs):
-        if method == "GET":
-            self.get_request_count += 1
-        elif method == "HEAD":
-            self.head_request_count += 1
-        else:
-            self.other_request_count += 1
+    def __getitem__(self, key: str) -> int:
+        return self._counter[key]

-        return requests.request(method=method, **kwargs)
+    @property
+    def total_calls(self) -> int:
+        return sum(self._counter.values())


 def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):


@@ -482,25 +482,22 @@ class AutoModelTest(unittest.TestCase):
         with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

-    @unittest.skip(
-        "Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
-    )
     def test_cached_model_has_minimum_calls_to_head(self):
         # Make sure we have cached the model.
         _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
         with RequestCounter() as counter:
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)

         # With a sharded checkpoint
         _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
         with RequestCounter() as counter:
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)

     def test_attr_not_existing(self):
         from transformers.models.auto.auto_factory import _LazyAutoMapping
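The same mechanical substitution repeats in the remaining test files below. As a reading aid, the removed attributes map onto the new API as sketched here; note the `other_request_count` equivalence is inferred from the new code, not stated in the diff:

```py
from collections import defaultdict

# Stand-in for a finished RequestCounter that saw one HEAD and one POST.
counter = defaultdict(int, {"HEAD": 1, "POST": 1})
total_calls = sum(counter.values())  # what the total_calls property returns

get_request_count = counter["GET"]    # was: counter.get_request_count
head_request_count = counter["HEAD"]  # was: counter.head_request_count
other_request_count = total_calls - counter["GET"] - counter["HEAD"]  # was: counter.other_request_count

assert (get_request_count, head_request_count, other_request_count) == (0, 1, 1)
```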


@@ -301,14 +301,14 @@ class TFAutoModelTest(unittest.TestCase):
         _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
         with RequestCounter() as counter:
             _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)

         # With a sharded checkpoint
         _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
         with RequestCounter() as counter:
             _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)


@@ -419,14 +419,11 @@ class AutoTokenizerTest(unittest.TestCase):
         ):
             _ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")

-    @unittest.skip(
-        "Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
-    )
     def test_cached_tokenizer_has_minimum_calls_to_head(self):
         # Make sure we have cached the tokenizer.
         _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
         with RequestCounter() as counter:
             _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)


@@ -763,9 +763,9 @@ class CustomPipelineTest(unittest.TestCase):
         _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
         with RequestCounter() as counter:
             _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)

     @require_torch
     def test_chunk_pipeline_batching_single_file(self):