mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
fix test when tf is not here
This commit is contained in:
parent
59fe641b8b
commit
ad0ab9afe9
@ -11,6 +11,7 @@ jobs:
|
|||||||
- run: sudo pip install --progress-bar off .
|
- run: sudo pip install --progress-bar off .
|
||||||
- run: sudo pip install pytest codecov pytest-cov
|
- run: sudo pip install pytest codecov pytest-cov
|
||||||
- run: sudo pip install tensorboardX scikit-learn
|
- run: sudo pip install tensorboardX scikit-learn
|
||||||
|
- run: sudo pip install tensorflow==2.0.0-rc0
|
||||||
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
||||||
- run: python -m pytest -sv ./examples/
|
- run: python -m pytest -sv ./examples/
|
||||||
- run: codecov
|
- run: codecov
|
||||||
@ -24,6 +25,7 @@ jobs:
|
|||||||
- checkout
|
- checkout
|
||||||
- run: sudo pip install --progress-bar off .
|
- run: sudo pip install --progress-bar off .
|
||||||
- run: sudo pip install pytest codecov pytest-cov
|
- run: sudo pip install pytest codecov pytest-cov
|
||||||
|
- run: sudo pip install tensorflow==2.0.0-rc0
|
||||||
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
||||||
- run: codecov
|
- run: codecov
|
||||||
deploy_doc:
|
deploy_doc:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
__version__ = "1.2.0"
|
__version__ = "1.2.0"
|
||||||
|
|
||||||
# Work around to update TensorFlow's absl.logging threshold which alters the
|
# Work around to update TensorFlow's absl.logging threshold which alters the
|
||||||
# default Python logging output behavior when present.
|
# default Python logging output behavior when present.
|
||||||
# see: https://github.com/abseil/abseil-py/issues/99
|
# see: https://github.com/abseil/abseil-py/issues/99
|
||||||
@ -11,6 +12,10 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
# Tokenizer
|
# Tokenizer
|
||||||
from .tokenization_utils import (PreTrainedTokenizer)
|
from .tokenization_utils import (PreTrainedTokenizer)
|
||||||
from .tokenization_auto import AutoTokenizer
|
from .tokenization_auto import AutoTokenizer
|
||||||
@ -36,38 +41,63 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
|
|||||||
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
# Modeling
|
# Modeling
|
||||||
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
try:
|
||||||
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
|
import torch
|
||||||
AutoModelWithLMHead)
|
torch_available = True # pylint: disable=invalid-name
|
||||||
|
except ImportError:
|
||||||
|
torch_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining,
|
if torch_available:
|
||||||
BertForMaskedLM, BertForNextSentencePrediction,
|
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||||
BertForSequenceClassification, BertForMultipleChoice,
|
|
||||||
BertForTokenClassification, BertForQuestionAnswering,
|
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
||||||
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
|
||||||
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
|
AutoModelWithLMHead)
|
||||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
|
||||||
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining,
|
||||||
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
|
BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
BertForSequenceClassification, BertForMultipleChoice,
|
||||||
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
|
BertForTokenClassification, BertForQuestionAnswering,
|
||||||
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
|
||||||
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
|
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
||||||
XLNetForSequenceClassification, XLNetForQuestionAnswering,
|
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
|
||||||
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
|
||||||
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
GPT2LMHeadModel, GPT2DoubleHeadsModel,
|
||||||
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
|
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
|
||||||
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
XLNetForSequenceClassification, XLNetForQuestionAnswering,
|
||||||
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
||||||
|
XLMWithLMHeadModel, XLMForSequenceClassification,
|
||||||
|
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
|
||||||
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
||||||
|
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
||||||
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
|
# Optimization
|
||||||
|
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
||||||
|
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||||
|
|
||||||
|
|
||||||
|
# TensorFlow
|
||||||
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
tf_available = True # pylint: disable=invalid-name
|
||||||
|
except ImportError:
|
||||||
|
tf_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
if tf_available:
|
||||||
|
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||||
|
|
||||||
|
from .modeling_tf_utils import TFPreTrainedModel
|
||||||
|
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
|
||||||
|
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_pt_weights_in_bert)
|
||||||
|
|
||||||
# Optimization
|
|
||||||
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
|
||||||
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
|
||||||
|
|
||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
|
@ -19,15 +19,19 @@ from __future__ import print_function
|
|||||||
import unittest
|
import unittest
|
||||||
import shutil
|
import shutil
|
||||||
import pytest
|
import pytest
|
||||||
|
import sys
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
from pytorch_transformers import (BertConfig)
|
|
||||||
from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
||||||
|
|
||||||
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
|
||||||
from .configuration_common_test import ConfigTester
|
from .configuration_common_test import ConfigTester
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from pytorch_transformers import (BertConfig)
|
||||||
|
from pytorch_transformers.modeling_tf_bert import TFBertModel, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||||
|
|
||||||
@ -283,39 +287,48 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_bert_model(self):
|
def test_bert_model(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_model(*config_and_inputs)
|
self.model_tester.create_and_check_bert_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_for_masked_lm(self):
|
def test_for_masked_lm(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_for_multiple_choice(self):
|
def test_for_multiple_choice(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_for_next_sequence_prediction(self):
|
def test_for_next_sequence_prediction(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_for_pretraining(self):
|
def test_for_pretraining(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_for_question_answering(self):
|
def test_for_question_answering(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_for_sequence_classification(self):
|
def test_for_sequence_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_for_token_classification(self):
|
def test_for_token_classification(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
|
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
cache_dir = "/tmp/pytorch_transformers_test/"
|
cache_dir = "/tmp/pytorch_transformers_test/"
|
||||||
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
@ -325,3 +338,4 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
@ -12,24 +12,25 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import shutil
|
||||||
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import unittest
|
import pytest
|
||||||
import logging
|
import sys
|
||||||
|
|
||||||
import tensorflow as tf
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
from pytorch_transformers import TFPreTrainedModel
|
from pytorch_transformers import TFPreTrainedModel
|
||||||
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
# from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _config_zero_init(config):
|
def _config_zero_init(config):
|
||||||
@ -49,6 +50,7 @@ class TFCommonTestCases:
|
|||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
pass
|
pass
|
||||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@ -62,6 +64,7 @@ class TFCommonTestCases:
|
|||||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
pass
|
pass
|
||||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@ -102,6 +105,7 @@ class TFCommonTestCases:
|
|||||||
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
|
# self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_headmasking(self):
|
def test_headmasking(self):
|
||||||
pass
|
pass
|
||||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@ -149,6 +153,7 @@ class TFCommonTestCases:
|
|||||||
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
# attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_head_pruning(self):
|
def test_head_pruning(self):
|
||||||
pass
|
pass
|
||||||
# if not self.test_pruning:
|
# if not self.test_pruning:
|
||||||
@ -176,6 +181,7 @@ class TFCommonTestCases:
|
|||||||
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
# attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
pass
|
pass
|
||||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@ -195,6 +201,7 @@ class TFCommonTestCases:
|
|||||||
# [self.model_tester.seq_length, self.model_tester.hidden_size])
|
# [self.model_tester.seq_length, self.model_tester.hidden_size])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
pass
|
pass
|
||||||
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
# original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@ -231,6 +238,7 @@ class TFCommonTestCases:
|
|||||||
# self.assertTrue(models_equal)
|
# self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_tie_model_weights(self):
|
def test_tie_model_weights(self):
|
||||||
pass
|
pass
|
||||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@ -282,6 +290,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
|
|||||||
|
|
||||||
|
|
||||||
class TFModelUtilsTest(unittest.TestCase):
|
class TFModelUtilsTest(unittest.TestCase):
|
||||||
|
@pytest.mark.skipif('tf' not in sys.modules, reason="requires TensorFlow")
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
pass
|
pass
|
||||||
# logging.basicConfig(level=logging.INFO)
|
# logging.basicConfig(level=logging.INFO)
|
||||||
|
Loading…
Reference in New Issue
Block a user