From d32585a304107cb9f42ccb0e1278405aa3eb6c9c Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 21 Apr 2020 10:49:00 -0400 Subject: [PATCH] Fix Torch.hub + Integration test --- .github/workflows/github-torch-hub.yml | 32 ++++++++++++++++++++ hubconf.py | 6 ++-- src/transformers/modeling_albert.py | 7 ++--- src/transformers/modeling_electra.py | 7 ++--- src/transformers/tokenization_camembert.py | 3 +- src/transformers/tokenization_xlm_roberta.py | 3 +- 6 files changed, 43 insertions(+), 15 deletions(-) create mode 100644 .github/workflows/github-torch-hub.yml diff --git a/.github/workflows/github-torch-hub.yml b/.github/workflows/github-torch-hub.yml new file mode 100644 index 00000000000..a0ee5e4655b --- /dev/null +++ b/.github/workflows/github-torch-hub.yml @@ -0,0 +1,32 @@ +name: Torch hub integration + +on: + push: + branches: + - "*" + +jobs: + torch_hub_integration: + runs-on: ubuntu-latest + steps: + # no checkout necessary here. + - name: Extract branch name + run: echo "::set-env name=BRANCH::${GITHUB_REF#refs/heads/}" + - name: Check branch name + run: echo $BRANCH + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: 3.7 + - name: Install dependencies + run: | + pip install torch + pip install numpy tokenizers boto3 filelock requests tqdm regex sentencepiece sacremoses + + - name: Torch hub list + run: | + python -c "import torch; print(torch.hub.list('huggingface/transformers:$BRANCH'))" + + - name: Torch hub help + run: | + python -c "import torch; print(torch.hub.help('huggingface/transformers:$BRANCH', 'modelForSequenceClassification'))" diff --git a/hubconf.py b/hubconf.py index 4e5c1b4b01d..b473d41c14b 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,4 +1,4 @@ -from transformers import ( +from src.transformers import ( AutoConfig, AutoModel, AutoModelForQuestionAnswering, @@ -6,10 +6,10 @@ from transformers import ( AutoModelWithLMHead, AutoTokenizer, ) -from transformers.file_utils import add_start_docstrings +from src.transformers.file_utils import add_start_docstrings -dependencies = ["torch", "tqdm", "boto3", "requests", "regex", "sentencepiece", "sacremoses"] +dependencies = ["torch", "numpy", "tokenizers", "boto3", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"] @add_start_docstrings(AutoConfig.__doc__) diff --git a/src/transformers/modeling_albert.py b/src/transformers/modeling_albert.py index a231f024392..d7d678bde93 100644 --- a/src/transformers/modeling_albert.py +++ b/src/transformers/modeling_albert.py @@ -22,11 +22,10 @@ import torch import torch.nn as nn from torch.nn import CrossEntropyLoss, MSELoss -from transformers.configuration_albert import AlbertConfig -from transformers.modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer -from transformers.modeling_utils import PreTrainedModel - +from .configuration_albert import AlbertConfig from .file_utils import add_start_docstrings, add_start_docstrings_to_callable +from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer +from .modeling_utils import PreTrainedModel logger = logging.getLogger(__name__) diff --git a/src/transformers/modeling_electra.py b/src/transformers/modeling_electra.py index ffe66b073a2..0626f5eb33a 100644 --- a/src/transformers/modeling_electra.py +++ b/src/transformers/modeling_electra.py @@ -4,10 +4,9 @@ import os import torch import torch.nn as nn -from transformers import ElectraConfig, add_start_docstrings -from transformers.activations import get_activation - -from .file_utils import add_start_docstrings_to_callable +from .activations import get_activation +from .configuration_electra import ElectraConfig +from .file_utils import add_start_docstrings, add_start_docstrings_to_callable from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel diff --git a/src/transformers/tokenization_camembert.py b/src/transformers/tokenization_camembert.py index 0179d020bf0..4a1069f737c 100644 --- a/src/transformers/tokenization_camembert.py +++ b/src/transformers/tokenization_camembert.py @@ -22,8 +22,7 @@ from typing import List, Optional import sentencepiece as spm -from transformers.tokenization_utils import PreTrainedTokenizer - +from .tokenization_utils import PreTrainedTokenizer from .tokenization_xlnet import SPIECE_UNDERLINE diff --git a/src/transformers/tokenization_xlm_roberta.py b/src/transformers/tokenization_xlm_roberta.py index f26634410ec..f5331ff166f 100644 --- a/src/transformers/tokenization_xlm_roberta.py +++ b/src/transformers/tokenization_xlm_roberta.py @@ -20,8 +20,7 @@ import os from shutil import copyfile from typing import List, Optional -from transformers.tokenization_utils import PreTrainedTokenizer - +from .tokenization_utils import PreTrainedTokenizer from .tokenization_xlnet import SPIECE_UNDERLINE