add base model classes to bart subclassed models (#9230)

* add base model classes to  bart subclassed models

* add doc
This commit is contained in:
Suraj Patil 2020-12-21 19:56:46 +05:30 committed by GitHub
parent 08abdabda1
commit f4432b7e01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 134 additions and 17 deletions

View File

@ -100,6 +100,15 @@ BlenderbotSmallTokenizer
:members:
BlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
See :obj:`transformers.BartModel` for arguments to `forward` and `generate`
.. autoclass:: transformers.BlenderbotModel
:members:
BlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -97,6 +97,13 @@ MBartTokenizerFast
:members:
MBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartModel
:members:
MBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -119,6 +119,12 @@ PegasusTokenizerFast
:members:
PegasusModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PegasusModel
PegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -406,7 +406,11 @@ if is_torch_available():
BertGenerationEncoder,
load_tf_weights_in_bert_generation,
)
from .models.blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForConditionalGeneration,
BlenderbotModel,
)
from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
CamembertForCausalLM,
@ -522,7 +526,7 @@ if is_torch_available():
LxmertXLayer,
)
from .models.marian import MarianMTModel
from .models.mbart import MBartForConditionalGeneration
from .models.mbart import MBartForConditionalGeneration, MBartModel
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
from .models.mobilebert import (
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@ -559,7 +563,7 @@ if is_torch_available():
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
from .models.pegasus import PegasusForConditionalGeneration
from .models.pegasus import PegasusForConditionalGeneration, PegasusModel
from .models.prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder,

View File

@ -50,7 +50,7 @@ from ..bert.modeling_bert import (
BertModel,
)
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration
from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel
from ..camembert.modeling_camembert import (
CamembertForCausalLM,
CamembertForMaskedLM,
@ -111,7 +111,7 @@ from ..longformer.modeling_longformer import (
)
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
from ..marian.modeling_marian import MarianMTModel
from ..mbart.modeling_mbart import MBartForConditionalGeneration
from ..mbart.modeling_mbart import MBartForConditionalGeneration, MBartModel
from ..mobilebert.modeling_mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
@ -132,7 +132,7 @@ from ..mpnet.modeling_mpnet import (
)
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function
RagModel,
@ -255,6 +255,10 @@ MODEL_MAPPING = OrderedDict(
(RetriBertConfig, RetriBertModel),
(MT5Config, MT5Model),
(T5Config, T5Model),
(PegasusConfig, PegasusModel),
(MarianConfig, MarianMTModel),
(MBartConfig, MBartModel),
(BlenderbotConfig, BlenderbotModel),
(DistilBertConfig, DistilBertModel),
(AlbertConfig, AlbertModel),
(CamembertConfig, CamembertModel),

View File

@ -22,7 +22,11 @@ from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokeniz
if is_torch_available():
from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
from .modeling_blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForConditionalGeneration,
BlenderbotModel,
)
if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration

View File

@ -19,7 +19,7 @@
import torch
from ...file_utils import add_start_docstrings
from ..bart.modeling_bart import BartForConditionalGeneration
from ..bart.modeling_bart import BartForConditionalGeneration, BartModel
from .configuration_blenderbot import BlenderbotConfig
@ -39,7 +39,20 @@ BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/blenderbot-3B", "facebook/
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
"The bare BlenderBot Model transformer outputting raw hidden-states without any specific head on top.",
BLENDER_START_DOCSTRING,
)
class BlenderbotModel(BartModel):
r"""
This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
documentation alongside usage examples.
"""
config_class = BlenderbotConfig
@add_start_docstrings(
"The BlenderBot Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
)
class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
"""

View File

@ -27,7 +27,7 @@ if is_tokenizers_available():
from .tokenization_mbart_fast import MBartTokenizerFast
if is_torch_available():
from .modeling_mbart import MBartForConditionalGeneration
from .modeling_mbart import MBartForConditionalGeneration, MBartModel
if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..bart.modeling_bart import BartForConditionalGeneration
from ..bart.modeling_bart import BartForConditionalGeneration, BartModel
from .configuration_mbart import MBartConfig
@ -26,6 +26,23 @@ MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
class MBartModel(BartModel):
r"""
This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
documentation alongside usage examples.
"""
config_class = MBartConfig
_keys_to_ignore_on_load_missing = [
"encoder.embed_positions.weight",
"decoder.embed_positions.weight",
]
_keys_to_ignore_on_save = [
"encoder.embed_positions.weight",
"decoder.embed_positions.weight",
]
class MBartForConditionalGeneration(BartForConditionalGeneration):
r"""
This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the superclass for the

View File

@ -27,7 +27,7 @@ if is_tokenizers_available():
from .tokenization_pegasus_fast import PegasusTokenizerFast
if is_torch_available():
from .modeling_pegasus import PegasusForConditionalGeneration
from .modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration

View File

@ -16,10 +16,34 @@
from ...file_utils import add_start_docstrings
from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration, BartModel
from .configuration_pegasus import PegasusConfig
@add_start_docstrings(
"The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
class PegasusModel(BartModel):
r"""
This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
documentation alongside usage examples.
"""
config_class = PegasusConfig
_keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"encoder\.version",
r"decoder\.version",
"encoder.embed_positions",
"decoder.embed_positions",
]
_keys_to_ignore_on_save = [
"encoder.embed_positions.weight",
"decoder.embed_positions.weight",
]
@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
class PegasusForConditionalGeneration(BartForConditionalGeneration):
r"""

View File

@ -600,6 +600,15 @@ class BlenderbotForConditionalGeneration:
requires_pytorch(self)
class BlenderbotModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@ -1297,6 +1306,15 @@ class MBartForConditionalGeneration:
requires_pytorch(self)
class MBartModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class MMBTForClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@ -1560,6 +1578,15 @@ class PegasusForConditionalGeneration:
requires_pytorch(self)
class PegasusModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@ -32,6 +32,7 @@ if is_torch_available():
AutoTokenizer,
BlenderbotConfig,
BlenderbotForConditionalGeneration,
BlenderbotModel,
BlenderbotSmallTokenizer,
BlenderbotTokenizer,
)
@ -90,7 +91,7 @@ class BlenderbotModelTester:
class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
if is_torch_available():
all_generative_model_classes = (BlenderbotForConditionalGeneration,)
all_model_classes = (BlenderbotForConditionalGeneration,)
all_model_classes = (BlenderbotForConditionalGeneration, BlenderbotModel)
else:
all_generative_model_classes = ()
all_model_classes = ()

View File

@ -30,6 +30,7 @@ if is_torch_available():
BatchEncoding,
MBartConfig,
MBartForConditionalGeneration,
MBartModel,
)
@ -59,7 +60,7 @@ class ModelTester:
@require_torch
class SelectiveCommonTest(unittest.TestCase):
all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (MBartForConditionalGeneration, MBartModel) if is_torch_available() else ()
test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save

View File

@ -26,7 +26,7 @@ from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest
if is_torch_available():
from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel
XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """
@ -55,7 +55,7 @@ class ModelTester:
@require_torch
class SelectiveCommonTest(unittest.TestCase):
all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
all_model_classes = (PegasusForConditionalGeneration, PegasusModel) if is_torch_available() else ()
test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save