Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
adding proxies options for the from_pretrained methods
parent 6d0aa73981
commit 43489756ad
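In practice the new keyword rides along with any from_pretrained call. A minimal usage sketch, assuming the pytorch_transformers package of this era; the checkpoint name is illustrative and the proxy addresses are the placeholders used in the docstrings below:

    from pytorch_transformers import BertModel, BertTokenizer

    # Proxy servers keyed by protocol, or by protocol://host for a single
    # endpoint, in the format the `requests` library expects (placeholders).
    proxies = {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}

    # Every request made while resolving the pretrained files goes through the proxies.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', proxies=proxies)  # illustrative checkpoint
    model = BertModel.from_pretrained('bert-base-uncased', proxies=proxies)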
.gitignore (vendored): 4 changes

@@ -127,4 +127,6 @@ proc_data
 
 # examples
 runs
-examples/runs
+examples/runs
+
+data
pytorch_transformers/file_utils.py

@@ -17,8 +17,9 @@ from hashlib import sha256
 from io import open
 
 import boto3
+import requests
+from botocore.config import Config
 from botocore.exceptions import ClientError
-import requests
 from tqdm import tqdm
 
 try:
@@ -93,7 +94,7 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag
 
 
-def cached_path(url_or_filename, cache_dir=None, force_download=False):
+def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
@@ -114,7 +115,7 @@ def cached_path(url_or_filename, cache_dir=None, force_download=False):
 
     if parsed.scheme in ('http', 'https', 's3'):
         # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download)
+        return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
@@ -159,24 +160,24 @@ def s3_request(func):
 
 
 @s3_request
-def s3_etag(url):
+def s3_etag(url, proxies=None):
     """Check ETag on S3 object."""
-    s3_resource = boto3.resource("s3")
+    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
     bucket_name, s3_path = split_s3_path(url)
     s3_object = s3_resource.Object(bucket_name, s3_path)
     return s3_object.e_tag
 
 
 @s3_request
-def s3_get(url, temp_file):
+def s3_get(url, temp_file, proxies=None):
     """Pull a file directly from S3."""
-    s3_resource = boto3.resource("s3")
+    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
     bucket_name, s3_path = split_s3_path(url)
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
 
 
-def http_get(url, temp_file):
-    req = requests.get(url, stream=True)
+def http_get(url, temp_file, proxies=None):
+    req = requests.get(url, stream=True, proxies=proxies)
     content_length = req.headers.get('Content-Length')
     total = int(content_length) if content_length is not None else None
     progress = tqdm(unit="B", total=total)
@@ -187,7 +188,7 @@ def http_get(url, temp_file):
     progress.close()
 
 
-def get_from_cache(url, cache_dir=None, force_download=False):
+def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
@@ -204,10 +205,10 @@ def get_from_cache(url, cache_dir=None, force_download=False):
 
     # Get eTag to add to filename, if it exists.
     if url.startswith("s3://"):
-        etag = s3_etag(url)
+        etag = s3_etag(url, proxies=proxies)
     else:
         try:
-            response = requests.head(url, allow_redirects=True)
+            response = requests.head(url, allow_redirects=True, proxies=proxies)
             if response.status_code != 200:
                 etag = None
             else:
@@ -238,9 +239,9 @@ def get_from_cache(url, cache_dir=None, force_download=False):
 
             # GET file object
             if url.startswith("s3://"):
-                s3_get(url, temp_file)
+                s3_get(url, temp_file, proxies=proxies)
             else:
-                http_get(url, temp_file)
+                http_get(url, temp_file, proxies=proxies)
 
             # we are copying the file before closing it, so flush to avoid truncation
             temp_file.flush()
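For reference, the proxies mapping is consumed in two different ways downstream of cached_path. A standalone sketch of both call shapes (the URL is hypothetical; the proxy addresses are placeholders):

    import boto3
    import requests
    from botocore.config import Config

    proxies = {'http': 'foo.bar:3128', 'https': 'foo.bar:3128'}  # placeholder addresses

    # requests takes the mapping on each call, as in http_get and the HEAD request above.
    response = requests.head('https://example.com/model.tar.gz',  # hypothetical URL
                             allow_redirects=True, proxies=proxies)

    # botocore takes the same mapping through its Config object, as in s3_etag and s3_get.
    s3_resource = boto3.resource('s3', config=Config(proxies=proxies))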
pytorch_transformers/modeling_utils.py

@@ -128,6 +128,10 @@ class PretrainedConfig(object):
             force_download: (`optional`) boolean, default False:
                 Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
 
+            proxies: (`optional`) dict, default None:
+                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+                The proxies are used on each request.
+
             return_unused_kwargs: (`optional`) bool:
 
                 - If False, then this function returns just the final configuration object.
@@ -150,6 +154,7 @@ class PretrainedConfig(object):
         """
         cache_dir = kwargs.pop('cache_dir', None)
         force_download = kwargs.pop('force_download', False)
+        proxies = kwargs.pop('proxies', None)
         return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
 
         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
@@ -160,7 +165,7 @@ class PretrainedConfig(object):
             config_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download)
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                 logger.error(
@@ -407,6 +412,10 @@ class PreTrainedModel(nn.Module):
             force_download: (`optional`) boolean, default False:
                 Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
 
+            proxies: (`optional`) dict, default None:
+                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+                The proxies are used on each request.
+
             output_loading_info: (`optional`) boolean:
                 Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
 
@@ -432,6 +441,7 @@ class PreTrainedModel(nn.Module):
         cache_dir = kwargs.pop('cache_dir', None)
         from_tf = kwargs.pop('from_tf', False)
         force_download = kwargs.pop('force_download', False)
+        proxies = kwargs.pop('proxies', None)
         output_loading_info = kwargs.pop('output_loading_info', False)
 
         # Load config
@@ -462,7 +472,7 @@ class PreTrainedModel(nn.Module):
             archive_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download)
+            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                 logger.error(
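A short sketch of how the documented kwargs combine on the config side; the checkpoint name is illustrative, and return_unused_kwargs behaves as described in the docstring above:

    from pytorch_transformers import BertConfig

    # proxies is popped from kwargs first, so it never reaches the config itself.
    config, unused_kwargs = BertConfig.from_pretrained(
        'bert-base-uncased',                 # illustrative checkpoint
        proxies={'http': 'foo.bar:3128'},    # placeholder proxy
        return_unused_kwargs=True,
    )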
pytorch_transformers/tokenization_utils.py

@@ -196,6 +196,10 @@ class PreTrainedTokenizer(object):
             force_download: (`optional`) boolean, default False:
                 Force to (re-)download the vocabulary files and override the cached versions if they exists.
 
+            proxies: (`optional`) dict, default None:
+                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+                The proxies are used on each request.
+
             inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
 
             kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
@@ -227,6 +231,7 @@ class PreTrainedTokenizer(object):
     def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         cache_dir = kwargs.pop('cache_dir', None)
         force_download = kwargs.pop('force_download', False)
+        proxies = kwargs.pop('proxies', None)
 
         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
@@ -287,7 +292,7 @@ class PreTrainedTokenizer(object):
                 if file_path is None:
                     resolved_vocab_files[file_id] = None
                 else:
-                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download)
+                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
             except EnvironmentError:
                 if pretrained_model_name_or_path in s3_models:
                     logger.error("Couldn't reach server to download vocabulary.")
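On the tokenizer side the new kwarg combines with the usual __init__ overrides in the same way; a hedged sketch with an illustrative checkpoint name:

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased',               # illustrative checkpoint
        proxies={'http': 'foo.bar:3128'},  # used for each vocabulary-file request
        unk_token='[UNK]',                 # forwarded to the tokenizer __init__
    )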