mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
fix tests
This commit is contained in:
parent
78863f6b36
commit
a6bcfb8015
@ -56,8 +56,6 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
|
|||||||
|
|
||||||
# Modeling
|
# Modeling
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
|
||||||
|
|
||||||
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
||||||
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
|
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
|
||||||
AutoModelWithLMHead)
|
AutoModelWithLMHead)
|
||||||
@ -96,8 +94,6 @@ if is_torch_available():
|
|||||||
|
|
||||||
# TensorFlow
|
# TensorFlow
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
|
||||||
|
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
|
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
|
||||||
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
|
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
|
||||||
TFAutoModelWithLMHead)
|
TFAutoModelWithLMHead)
|
||||||
|
@ -23,16 +23,20 @@ from botocore.exceptions import ClientError
|
|||||||
import requests
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
assert int(tf.__version__[0]) >= 2
|
assert int(tf.__version__[0]) >= 2
|
||||||
_tf_available = True # pylint: disable=invalid-name
|
_tf_available = True # pylint: disable=invalid-name
|
||||||
|
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||||
except (ImportError, AssertionError):
|
except (ImportError, AssertionError):
|
||||||
_tf_available = False # pylint: disable=invalid-name
|
_tf_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
_torch_available = True # pylint: disable=invalid-name
|
_torch_available = True # pylint: disable=invalid-name
|
||||||
|
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_torch_available = False # pylint: disable=invalid-name
|
_torch_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
@ -67,8 +71,6 @@ TF2_WEIGHTS_NAME = 'tf_model.h5'
|
|||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
def is_torch_available():
|
def is_torch_available():
|
||||||
return _torch_available
|
return _torch_available
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ from .file_utils import cached_path, is_tf_available, is_torch_available
|
|||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
if is_torch_available()
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
Loading…
Reference in New Issue
Block a user