mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix RequestCounter to make it more future-proof (#27406)
* Fix RequestCounter to make it more future-proof
* code quality
This commit is contained in:
parent
c8b6052ff6
commit
e38348ae8f
@ -29,14 +29,15 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import huggingface_hub
|
||||
import requests
|
||||
import urllib3
|
||||
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False):
|
||||
class RequestCounter:
    """
    Helper class that will count all requests made online.

    Might not be robust if urllib3 changes its logging format but should be good enough for us.

    Usage:
    ```py
    with RequestCounter() as counter:
        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
    assert counter["GET"] == 0
    assert counter["HEAD"] == 1
    assert counter.total_calls == 1
    ```
    """

    def __enter__(self):
        # Spy on urllib3's connection-pool debug logger instead of monkey-patching
        # the HTTP session: `wraps=` keeps the original logging behavior intact
        # while recording every call for inspection on exit.
        self._counter = defaultdict(int)
        self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
        self.mock = self.patcher.start()
        return self

    def __exit__(self, *args, **kwargs) -> None:
        # Each recorded debug message mentions the HTTP verb of the request it
        # traces; render the message and tally the first verb found in it.
        for call in self.mock.call_args_list:
            log = call.args[0] % call.args[1:]
            for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
                if method in log:
                    self._counter[method] += 1
                    break
        self.patcher.stop()

    def __getitem__(self, key: str) -> int:
        # e.g. `counter["GET"]` -> number of GET requests made inside the context.
        return self._counter[key]

    @property
    def total_calls(self) -> int:
        # Total number of requests made, summed over all HTTP methods.
        return sum(self._counter.values())
||||
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
|
||||
|
@ -482,25 +482,22 @@ class AutoModelTest(unittest.TestCase):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
@unittest.skip(
|
||||
"Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
|
||||
)
|
||||
def test_cached_model_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the model.
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
self.assertEqual(counter["GET"], 0)
|
||||
self.assertEqual(counter["HEAD"], 1)
|
||||
self.assertEqual(counter.total_calls, 1)
|
||||
|
||||
# With a sharded checkpoint
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
self.assertEqual(counter["GET"], 0)
|
||||
self.assertEqual(counter["HEAD"], 1)
|
||||
self.assertEqual(counter.total_calls, 1)
|
||||
|
||||
def test_attr_not_existing(self):
|
||||
from transformers.models.auto.auto_factory import _LazyAutoMapping
|
||||
|
@ -301,14 +301,14 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
self.assertEqual(counter["GET"], 0)
|
||||
self.assertEqual(counter["HEAD"], 1)
|
||||
self.assertEqual(counter.total_calls, 1)
|
||||
|
||||
# With a sharded checkpoint
|
||||
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
with RequestCounter() as counter:
|
||||
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
self.assertEqual(counter["GET"], 0)
|
||||
self.assertEqual(counter["HEAD"], 1)
|
||||
self.assertEqual(counter.total_calls, 1)
|
||||
|
@ -419,14 +419,11 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
):
|
||||
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
@unittest.skip(
|
||||
"Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
|
||||
)
|
||||
def test_cached_tokenizer_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the tokenizer.
|
||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
self.assertEqual(counter["GET"], 0)
|
||||
self.assertEqual(counter["HEAD"], 1)
|
||||
self.assertEqual(counter.total_calls, 1)
|
||||
|
@ -763,9 +763,9 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
self.assertEqual(counter["GET"], 0)
|
||||
self.assertEqual(counter["HEAD"], 1)
|
||||
self.assertEqual(counter.total_calls, 1)
|
||||
|
||||
@require_torch
|
||||
def test_chunk_pipeline_batching_single_file(self):
|
||||
|
Loading…
Reference in New Issue
Block a user