mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Remove sys.version_info[0] == 2 or 3.
This commit is contained in:
parent
8af25b1664
commit
798b3b3899
@ -24,7 +24,6 @@ import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -104,12 +103,7 @@ class InputFeatures(object):
|
||||
|
||||
def read_swag_examples(input_file, is_training=True):
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, "utf-8") for cell in line) # noqa: F821
|
||||
lines.append(line)
|
||||
lines = list(csv.reader(f))
|
||||
|
||||
if is_training and lines[0][-1] != "label":
|
||||
raise ValueError("For training, the input file must contain a label column.")
|
||||
|
@ -21,7 +21,6 @@ import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
from typing import List
|
||||
|
||||
@ -179,13 +178,7 @@ class SwagProcessor(DataProcessor):
|
||||
|
||||
def _read_csv(self, input_file):
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.reader(f)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, "utf-8") for cell in line) # noqa: F821
|
||||
lines.append(line)
|
||||
return lines
|
||||
return list(csv.reader(f))
|
||||
|
||||
def _create_examples(self, lines: List[List[str]], type: str):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
|
@ -18,6 +18,7 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
@ -34,12 +35,6 @@ from transformers import (
|
||||
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# We do this to be able to load python 2 datasets pickles
|
||||
|
@ -18,7 +18,6 @@ import copy
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from ...file_utils import is_tf_available, is_torch_available
|
||||
|
||||
@ -98,13 +97,7 @@ class DataProcessor(object):
|
||||
def _read_tsv(cls, input_file, quotechar=None):
|
||||
"""Reads a tab separated value file."""
|
||||
with open(input_file, "r", encoding="utf-8-sig") as f:
|
||||
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, "utf-8") for cell in line) # noqa: F821
|
||||
lines.append(line)
|
||||
return lines
|
||||
return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
|
||||
|
||||
|
||||
class SingleSentenceClassificationProcessor(DataProcessor):
|
||||
|
@ -166,7 +166,7 @@ def filename_to_url(filename, cache_dir=None):
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
@ -201,9 +201,9 @@ def cached_path(
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
|
||||
if isinstance(url_or_filename, Path):
|
||||
url_or_filename = str(url_or_filename)
|
||||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
if is_remote_url(url_or_filename):
|
||||
@ -314,9 +314,7 @@ def get_from_cache(
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
if not os.path.exists(cache_dir):
|
||||
@ -335,8 +333,6 @@ def get_from_cache(
|
||||
except (EnvironmentError, requests.exceptions.Timeout):
|
||||
etag = None
|
||||
|
||||
if sys.version_info[0] == 2 and etag is not None:
|
||||
etag = etag.decode("utf-8")
|
||||
filename = url_to_filename(url, etag)
|
||||
|
||||
# get cache path to put the file
|
||||
@ -400,9 +396,6 @@ def get_from_cache(
|
||||
meta = {"url": url, "etag": etag}
|
||||
meta_path = cache_path + ".json"
|
||||
with open(meta_path, "w") as meta_file:
|
||||
output_string = json.dumps(meta)
|
||||
if sys.version_info[0] == 2 and isinstance(output_string, str):
|
||||
output_string = unicode(output_string, "utf-8") # noqa: F821
|
||||
meta_file.write(output_string)
|
||||
json.dump(meta, meta_file)
|
||||
|
||||
return cache_path
|
||||
|
@ -19,7 +19,6 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -338,9 +337,7 @@ class BertIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertIntermediate, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str) or (
|
||||
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||
):
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
@ -460,9 +457,7 @@ class BertPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertPredictionHeadTransform, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
if isinstance(config.hidden_act, str) or (
|
||||
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||
):
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
|
@ -17,7 +17,6 @@
|
||||
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@ -311,9 +310,7 @@ class TFAlbertLayer(tf.keras.layers.Layer):
|
||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str) or (
|
||||
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||
):
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.activation = config.hidden_act
|
||||
@ -454,9 +451,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
if isinstance(config.hidden_act, str) or (
|
||||
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||
):
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.activation = config.hidden_act
|
||||
|
@ -17,7 +17,6 @@
|
||||
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@ -310,9 +309,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
if isinstance(config.hidden_act, str) or (
|
||||
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||
):
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
@ -417,9 +414,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
if isinstance(config.hidden_act, str) or (
|
||||
sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) # noqa: F821
|
||||
):
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
|
@ -18,7 +18,6 @@
|
||||
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
@ -290,9 +289,7 @@ class TFXLNetFeedForward(tf.keras.layers.Layer):
|
||||
config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2"
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
if isinstance(config.ff_activation, str) or (
|
||||
sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode) # noqa: F821
|
||||
):
|
||||
if isinstance(config.ff_activation, str):
|
||||
self.activation_function = ACT2FN[config.ff_activation]
|
||||
else:
|
||||
self.activation_function = config.ff_activation
|
||||
|
@ -19,7 +19,6 @@
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -420,9 +419,7 @@ class XLNetFeedForward(nn.Module):
|
||||
self.layer_1 = nn.Linear(config.d_model, config.d_inner)
|
||||
self.layer_2 = nn.Linear(config.d_inner, config.d_model)
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
if isinstance(config.ff_activation, str) or (
|
||||
sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode) # noqa: F821
|
||||
):
|
||||
if isinstance(config.ff_activation, str):
|
||||
self.activation_function = ACT2FN[config.ff_activation]
|
||||
else:
|
||||
self.activation_function = config.ff_activation
|
||||
|
@ -18,7 +18,6 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import regex as re
|
||||
@ -80,7 +79,6 @@ def bytes_to_unicode():
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
"""
|
||||
_chr = unichr if sys.version_info[0] == 2 else chr # noqa: F821
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
@ -91,7 +89,7 @@ def bytes_to_unicode():
|
||||
bs.append(b)
|
||||
cs.append(2 ** 8 + n)
|
||||
n += 1
|
||||
cs = [_chr(n) for n in cs]
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
@ -212,14 +210,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
|
||||
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
if sys.version_info[0] == 2:
|
||||
token = "".join(
|
||||
self.byte_encoder[ord(b)] for b in token
|
||||
) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
|
||||
else:
|
||||
token = "".join(
|
||||
self.byte_encoder[b] for b in token.encode("utf-8")
|
||||
) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
|
||||
token = "".join(
|
||||
self.byte_encoder[b] for b in token.encode("utf-8")
|
||||
) # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
|
||||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
||||
return bpe_tokens
|
||||
|
||||
|
@ -21,7 +21,7 @@
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
from collections import Counter, OrderedDict
|
||||
from io import open
|
||||
|
||||
@ -36,11 +36,6 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -16,8 +16,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .test_tokenization_common import TemporaryDirectory
|
||||
import tempfile
|
||||
|
||||
|
||||
class ConfigTester(object):
|
||||
@ -42,7 +41,7 @@ class ConfigTester(object):
|
||||
def create_and_test_config_to_json_file(self):
|
||||
config_first = self.config_class(**self.inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
json_file_path = os.path.join(tmpdirname, "config.json")
|
||||
config_first.to_json_file(json_file_path)
|
||||
config_second = self.config_class.from_json_file(json_file_path)
|
||||
@ -52,7 +51,7 @@ class ConfigTester(object):
|
||||
def create_and_test_config_from_and_save_pretrained(self):
|
||||
config_first = self.config_class(**self.inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
config_first.save_pretrained(tmpdirname)
|
||||
config_second = self.config_class.from_pretrained(tmpdirname)
|
||||
|
||||
|
@ -16,12 +16,11 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.modelcard import ModelCard
|
||||
|
||||
from .test_tokenization_common import TemporaryDirectory
|
||||
|
||||
|
||||
class ModelCardTester(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@ -65,7 +64,7 @@ class ModelCardTester(unittest.TestCase):
|
||||
def test_model_card_to_json_file(self):
|
||||
model_card_first = ModelCard.from_dict(self.inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
filename = os.path.join(tmpdirname, "modelcard.json")
|
||||
model_card_first.to_json_file(filename)
|
||||
model_card_second = ModelCard.from_json_file(filename)
|
||||
@ -75,7 +74,7 @@ class ModelCardTester(unittest.TestCase):
|
||||
def test_model_card_from_and_save_pretrained(self):
|
||||
model_card_first = ModelCard.from_dict(self.inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model_card_first.save_pretrained(tmpdirname)
|
||||
model_card_second = ModelCard.from_pretrained(tmpdirname)
|
||||
|
||||
|
@ -19,8 +19,6 @@ import json
|
||||
import logging
|
||||
import os.path
|
||||
import random
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
@ -43,23 +41,6 @@ if is_torch_available():
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
|
||||
class TemporaryDirectory(object):
|
||||
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||
|
||||
def __enter__(self):
|
||||
self.name = tempfile.mkdtemp()
|
||||
return self.name
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
shutil.rmtree(self.name)
|
||||
|
||||
|
||||
else:
|
||||
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||
unicode = str
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@ -92,7 +73,7 @@ class ModelTesterMixin:
|
||||
out_2 = outputs[0].numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
model.to(torch_device)
|
||||
@ -238,7 +219,7 @@ class ModelTesterMixin:
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
with TemporaryDirectory() as tmp_dir_name:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
@ -366,7 +347,7 @@ class ModelTesterMixin:
|
||||
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
|
||||
model.prune_heads(heads_to_prune)
|
||||
|
||||
with TemporaryDirectory() as temp_dir_name:
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name)
|
||||
model.to(torch_device)
|
||||
@ -435,7 +416,7 @@ class ModelTesterMixin:
|
||||
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
|
||||
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
|
||||
|
||||
with TemporaryDirectory() as temp_dir_name:
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name)
|
||||
model.to(torch_device)
|
||||
|
@ -17,8 +17,6 @@
|
||||
import copy
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
@ -32,23 +30,6 @@ if is_tf_available():
|
||||
|
||||
# from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
|
||||
class TemporaryDirectory(object):
|
||||
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||
|
||||
def __enter__(self):
|
||||
self.name = tempfile.mkdtemp()
|
||||
return self.name
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
shutil.rmtree(self.name)
|
||||
|
||||
|
||||
else:
|
||||
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||
unicode = str
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@ -87,7 +68,7 @@ class TFModelTesterMixin:
|
||||
model = model_class(config)
|
||||
outputs = model(inputs_dict)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
after_outputs = model(inputs_dict)
|
||||
@ -137,7 +118,7 @@ class TFModelTesterMixin:
|
||||
self.assertLessEqual(max_diff, 2e-2)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||
@ -180,7 +161,7 @@ class TFModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
outputs = model(inputs_dict) # build the model
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
@ -15,11 +15,11 @@
|
||||
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .test_tokenization_common import TemporaryDirectory
|
||||
from .utils import require_torch
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
|
||||
scheduler.step()
|
||||
lrs.append(scheduler.get_lr())
|
||||
if step == num_steps // 2:
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
file_name = os.path.join(tmpdirname, "schedule.bin")
|
||||
torch.save(scheduler.state_dict(), file_name)
|
||||
|
||||
|
@ -15,33 +15,12 @@
|
||||
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from io import open
|
||||
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
|
||||
class TemporaryDirectory(object):
|
||||
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
|
||||
|
||||
def __enter__(self):
|
||||
self.name = tempfile.mkdtemp()
|
||||
return self.name
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
shutil.rmtree(self.name)
|
||||
|
||||
|
||||
else:
|
||||
import pickle
|
||||
|
||||
TemporaryDirectory = tempfile.TemporaryDirectory
|
||||
unicode = str
|
||||
|
||||
|
||||
class TokenizerTesterMixin:
|
||||
|
||||
tokenizer_class = None
|
||||
@ -90,7 +69,7 @@ class TokenizerTesterMixin:
|
||||
|
||||
before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
|
||||
|
||||
@ -108,7 +87,7 @@ class TokenizerTesterMixin:
|
||||
text = "Munich and Berlin are nice cities"
|
||||
subwords = tokenizer.tokenize(text)
|
||||
|
||||
with TemporaryDirectory() as tmpdirname:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
||||
filename = os.path.join(tmpdirname, "tokenizer.bin")
|
||||
with open(filename, "wb") as handle:
|
||||
@ -246,7 +225,7 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(text_2, output_text)
|
||||
|
||||
self.assertNotEqual(len(tokens_2), 0)
|
||||
self.assertIsInstance(text_2, (str, unicode))
|
||||
self.assertIsInstance(text_2, str)
|
||||
|
||||
def test_encode_decode_with_spaces(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
@ -268,9 +247,6 @@ class TokenizerTesterMixin:
|
||||
self.assertListEqual(weights_list, weights_list_2)
|
||||
|
||||
def test_mask_output(self):
|
||||
if sys.version_info <= (3, 0):
|
||||
return
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer":
|
||||
|
Loading…
Reference in New Issue
Block a user