diff --git a/docs/source/model_doc/blenderbot.rst b/docs/source/model_doc/blenderbot.rst index ddceeb81c1b..df43c90ef07 100644 --- a/docs/source/model_doc/blenderbot.rst +++ b/docs/source/model_doc/blenderbot.rst @@ -100,6 +100,15 @@ BlenderbotSmallTokenizer :members: +BlenderbotModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +See :obj:`transformers.BartModel` for arguments to `forward` and `generate` + +.. autoclass:: transformers.BlenderbotModel + :members: + + BlenderbotForConditionalGeneration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst index eb9b9798024..4ac391255eb 100644 --- a/docs/source/model_doc/mbart.rst +++ b/docs/source/model_doc/mbart.rst @@ -97,6 +97,13 @@ MBartTokenizerFast :members: +MBartModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MBartModel + :members: + + MBartForConditionalGeneration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/pegasus.rst b/docs/source/model_doc/pegasus.rst index 42b3e5ea57b..3fab320ebcb 100644 --- a/docs/source/model_doc/pegasus.rst +++ b/docs/source/model_doc/pegasus.rst @@ -119,6 +119,12 @@ PegasusTokenizerFast :members: +PegasusModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.PegasusModel + + PegasusForConditionalGeneration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4586fe5363f..580318abaa2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -406,7 +406,11 @@ if is_torch_available(): BertGenerationEncoder, load_tf_weights_in_bert_generation, ) - from .models.blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration + from .models.blenderbot import ( + BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotForConditionalGeneration, + BlenderbotModel, + ) from .models.camembert import ( CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, CamembertForCausalLM, @@ -522,7 +526,7 @@ if is_torch_available(): LxmertXLayer, ) from .models.marian import MarianMTModel - from .models.mbart import MBartForConditionalGeneration + from .models.mbart import MBartForConditionalGeneration, MBartModel from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings from .models.mobilebert import ( MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -559,7 +563,7 @@ if is_torch_available(): OpenAIGPTPreTrainedModel, load_tf_weights_in_openai_gpt, ) - from .models.pegasus import PegasusForConditionalGeneration + from .models.pegasus import PegasusForConditionalGeneration, PegasusModel from .models.prophetnet import ( PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, ProphetNetDecoder, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4b9141d0245..3fc5c702e7d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -50,7 +50,7 @@ from ..bert.modeling_bert import ( BertModel, ) from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder -from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration +from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel from ..camembert.modeling_camembert import ( CamembertForCausalLM, CamembertForMaskedLM, @@ -111,7 +111,7 @@ from ..longformer.modeling_longformer import ( ) from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel from ..marian.modeling_marian import MarianMTModel -from ..mbart.modeling_mbart import MBartForConditionalGeneration +from ..mbart.modeling_mbart import MBartForConditionalGeneration, MBartModel from ..mobilebert.modeling_mobilebert import ( MobileBertForMaskedLM, MobileBertForMultipleChoice, @@ -132,7 +132,7 @@ from ..mpnet.modeling_mpnet import ( ) from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel -from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration +from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel from ..rag.modeling_rag import ( # noqa: F401 - need to import all RagModels to be in globals() function RagModel, @@ -255,6 +255,10 @@ MODEL_MAPPING = OrderedDict( (RetriBertConfig, RetriBertModel), (MT5Config, MT5Model), (T5Config, T5Model), + (PegasusConfig, PegasusModel), + (MarianConfig, MarianMTModel), + (MBartConfig, MBartModel), + (BlenderbotConfig, BlenderbotModel), (DistilBertConfig, DistilBertModel), (AlbertConfig, AlbertModel), (CamembertConfig, CamembertModel), diff --git a/src/transformers/models/blenderbot/__init__.py b/src/transformers/models/blenderbot/__init__.py index fdcd990ff93..fccb38f80ac 100644 --- a/src/transformers/models/blenderbot/__init__.py +++ b/src/transformers/models/blenderbot/__init__.py @@ -22,7 +22,11 @@ from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokeniz if is_torch_available(): - from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration + from .modeling_blenderbot import ( + BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotForConditionalGeneration, + BlenderbotModel, + ) if is_tf_available(): from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1421a87ca9b..2a370fbabf8 100644 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -19,7 +19,7 @@ import torch from ...file_utils import add_start_docstrings -from ..bart.modeling_bart import BartForConditionalGeneration +from ..bart.modeling_bart import BartForConditionalGeneration, BartModel from .configuration_blenderbot import BlenderbotConfig @@ -39,7 +39,20 @@ BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/blenderbot-3B", "facebook/ @add_start_docstrings( - "The BART Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING + "The bare BlenderBot Model transformer outputting raw hidden-states without any specific head on top.", + BLENDER_START_DOCSTRING, +) +class BlenderbotModel(BartModel): + r""" + This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = BlenderbotConfig + + +@add_start_docstrings( + "The BlenderBot Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING ) class BlenderbotForConditionalGeneration(BartForConditionalGeneration): """ diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index b98d2266250..2fa8876085e 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -27,7 +27,7 @@ if is_tokenizers_available(): from .tokenization_mbart_fast import MBartTokenizerFast if is_torch_available(): - from .modeling_mbart import MBartForConditionalGeneration + from .modeling_mbart import MBartForConditionalGeneration, MBartModel if is_tf_available(): from .modeling_tf_mbart import TFMBartForConditionalGeneration diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 9fca52c5495..f4aa39b0751 100644 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..bart.modeling_bart import BartForConditionalGeneration +from ..bart.modeling_bart import BartForConditionalGeneration, BartModel from .configuration_mbart import MBartConfig @@ -26,6 +26,23 @@ MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] +class MBartModel(BartModel): + r""" + This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = MBartConfig + _keys_to_ignore_on_load_missing = [ + "encoder.embed_positions.weight", + "decoder.embed_positions.weight", + ] + _keys_to_ignore_on_save = [ + "encoder.embed_positions.weight", + "decoder.embed_positions.weight", + ] + + class MBartForConditionalGeneration(BartForConditionalGeneration): r""" This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the superclass for the diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py index e7cc0ce71be..20d1c3872dc 100644 --- a/src/transformers/models/pegasus/__init__.py +++ b/src/transformers/models/pegasus/__init__.py @@ -27,7 +27,7 @@ if is_tokenizers_available(): from .tokenization_pegasus_fast import PegasusTokenizerFast if is_torch_available(): - from .modeling_pegasus import PegasusForConditionalGeneration + from .modeling_pegasus import PegasusForConditionalGeneration, PegasusModel if is_tf_available(): from .modeling_tf_pegasus import TFPegasusForConditionalGeneration diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 3e623a77040..c7fde416433 100644 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -16,10 +16,34 @@ from ...file_utils import add_start_docstrings -from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration +from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration, BartModel from .configuration_pegasus import PegasusConfig +@add_start_docstrings( + "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.", + BART_START_DOCSTRING, +) +class PegasusModel(BartModel): + r""" + This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate + documentation alongside usage examples. + """ + + config_class = PegasusConfig + _keys_to_ignore_on_load_missing = [ + r"final_logits_bias", + r"encoder\.version", + r"decoder\.version", + "encoder.embed_positions", + "decoder.embed_positions", + ] + _keys_to_ignore_on_save = [ + "encoder.embed_positions.weight", + "decoder.embed_positions.weight", + ] + + @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING) class PegasusForConditionalGeneration(BartForConditionalGeneration): r""" diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 050c7ba4f90..97669eff742 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -600,6 +600,15 @@ class BlenderbotForConditionalGeneration: requires_pytorch(self) +class BlenderbotModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -1297,6 +1306,15 @@ class MBartForConditionalGeneration: requires_pytorch(self) +class MBartModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + class MMBTForClassification: def __init__(self, *args, **kwargs): requires_pytorch(self) @@ -1560,6 +1578,15 @@ class PegasusForConditionalGeneration: requires_pytorch(self) +class PegasusModel: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_blenderbot.py b/tests/test_modeling_blenderbot.py index b069ba6089b..668569a5955 100644 --- a/tests/test_modeling_blenderbot.py +++ b/tests/test_modeling_blenderbot.py @@ -32,6 +32,7 @@ if is_torch_available(): AutoTokenizer, BlenderbotConfig, BlenderbotForConditionalGeneration, + BlenderbotModel, BlenderbotSmallTokenizer, BlenderbotTokenizer, ) @@ -90,7 +91,7 @@ class BlenderbotModelTester: class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase): if is_torch_available(): all_generative_model_classes = (BlenderbotForConditionalGeneration,) - all_model_classes = (BlenderbotForConditionalGeneration,) + all_model_classes = (BlenderbotForConditionalGeneration, BlenderbotModel) else: all_generative_model_classes = () all_model_classes = () diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 1a4094ed2ce..2a43650febb 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -30,6 +30,7 @@ if is_torch_available(): BatchEncoding, MBartConfig, MBartForConditionalGeneration, + MBartModel, ) @@ -59,7 +60,7 @@ class ModelTester: @require_torch class SelectiveCommonTest(unittest.TestCase): - all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else () + all_model_classes = (MBartForConditionalGeneration, MBartModel) if is_torch_available() else () test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index 42173ebccfb..dc9fdf52254 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -26,7 +26,7 @@ from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest if is_torch_available(): - from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration + from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """ @@ -55,7 +55,7 @@ class ModelTester: @require_torch class SelectiveCommonTest(unittest.TestCase): - all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else () + all_model_classes = (PegasusForConditionalGeneration, PegasusModel) if is_torch_available() else () test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save