mirror of https://github.com/huggingface/transformers.git
add base model classes to bart subclassed models (#9230)
* add base model classes to bart subclassed models
* add doc
parent 08abdabda1
commit f4432b7e01
@@ -100,6 +100,15 @@ BlenderbotSmallTokenizer
     :members:
 
 
+BlenderbotModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+See :obj:`transformers.BartModel` for arguments to `forward` and `generate`
+
+.. autoclass:: transformers.BlenderbotModel
+    :members:
+
+
 BlenderbotForConditionalGeneration
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -97,6 +97,13 @@ MBartTokenizerFast
     :members:
 
 
+MBartModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.MBartModel
+    :members:
+
+
 MBartForConditionalGeneration
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -119,6 +119,12 @@ PegasusTokenizerFast
     :members:
 
 
+PegasusModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.PegasusModel
+
+
 PegasusForConditionalGeneration
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -406,7 +406,11 @@ if is_torch_available():
         BertGenerationEncoder,
         load_tf_weights_in_bert_generation,
     )
-    from .models.blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
+    from .models.blenderbot import (
+        BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        BlenderbotForConditionalGeneration,
+        BlenderbotModel,
+    )
     from .models.camembert import (
         CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         CamembertForCausalLM,
@@ -522,7 +526,7 @@ if is_torch_available():
         LxmertXLayer,
     )
     from .models.marian import MarianMTModel
-    from .models.mbart import MBartForConditionalGeneration
+    from .models.mbart import MBartForConditionalGeneration, MBartModel
     from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
     from .models.mobilebert import (
         MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -559,7 +563,7 @@ if is_torch_available():
         OpenAIGPTPreTrainedModel,
         load_tf_weights_in_openai_gpt,
     )
-    from .models.pegasus import PegasusForConditionalGeneration
+    from .models.pegasus import PegasusForConditionalGeneration, PegasusModel
     from .models.prophetnet import (
         PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
         ProphetNetDecoder,
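With the re-exports above in place, the new bare classes become importable from the package root. A quick sanity check (illustrative, not part of the diff):

```python
# Sanity check (sketch): all three new base classes resolve at the top level
# when torch is installed; otherwise the dummy objects further below are used.
from transformers import BlenderbotModel, MBartModel, PegasusModel

print(BlenderbotModel, MBartModel, PegasusModel)
```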
@@ -50,7 +50,7 @@ from ..bert.modeling_bert import (
     BertModel,
 )
 from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
-from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration
+from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel
 from ..camembert.modeling_camembert import (
     CamembertForCausalLM,
     CamembertForMaskedLM,
@@ -111,7 +111,7 @@ from ..longformer.modeling_longformer import (
 )
 from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
 from ..marian.modeling_marian import MarianMTModel
-from ..mbart.modeling_mbart import MBartForConditionalGeneration
+from ..mbart.modeling_mbart import MBartForConditionalGeneration, MBartModel
 from ..mobilebert.modeling_mobilebert import (
     MobileBertForMaskedLM,
     MobileBertForMultipleChoice,
@@ -132,7 +132,7 @@ from ..mpnet.modeling_mpnet import (
 )
 from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
 from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
-from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration
+from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
 from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
 from ..rag.modeling_rag import (  # noqa: F401 - need to import all RagModels to be in globals() function
     RagModel,
@@ -255,6 +255,10 @@ MODEL_MAPPING = OrderedDict(
         (RetriBertConfig, RetriBertModel),
         (MT5Config, MT5Model),
         (T5Config, T5Model),
+        (PegasusConfig, PegasusModel),
+        (MarianConfig, MarianMTModel),
+        (MBartConfig, MBartModel),
+        (BlenderbotConfig, BlenderbotModel),
         (DistilBertConfig, DistilBertModel),
         (AlbertConfig, AlbertModel),
         (CamembertConfig, CamembertModel),
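These `MODEL_MAPPING` entries are what let `AutoModel` instantiate the new bare classes from a config. A hedged sketch of the dispatch (the checkpoint name is taken from the archive list quoted later in this diff):

```python
from transformers import AutoConfig, AutoModel

# AutoModel dispatches on the config class through MODEL_MAPPING, so a
# Blenderbot config now yields the bare BlenderbotModel rather than an error.
config = AutoConfig.from_pretrained("facebook/blenderbot-3B")
model = AutoModel.from_config(config)
assert type(model).__name__ == "BlenderbotModel"
```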
@@ -22,7 +22,11 @@ from .tokenization_blenderbot import BlenderbotSmallTokenizer, BlenderbotTokeniz
 
 
 if is_torch_available():
-    from .modeling_blenderbot import BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForConditionalGeneration
+    from .modeling_blenderbot import (
+        BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        BlenderbotForConditionalGeneration,
+        BlenderbotModel,
+    )
 
 if is_tf_available():
     from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
@@ -19,7 +19,7 @@
 import torch
 
 from ...file_utils import add_start_docstrings
-from ..bart.modeling_bart import BartForConditionalGeneration
+from ..bart.modeling_bart import BartForConditionalGeneration, BartModel
 from .configuration_blenderbot import BlenderbotConfig
 
 
@@ -39,7 +39,20 @@ BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = ["facebook/blenderbot-3B", "facebook/
 
 
 @add_start_docstrings(
-    "The BART Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
+    "The bare BlenderBot Model transformer outputting raw hidden-states without any specific head on top.",
+    BLENDER_START_DOCSTRING,
+)
+class BlenderbotModel(BartModel):
+    r"""
+    This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
+    documentation alongside usage examples.
+    """
+
+    config_class = BlenderbotConfig
+
+
+@add_start_docstrings(
+    "The BlenderBot Model with a language modeling head. Can be used for summarization.", BLENDER_START_DOCSTRING
 )
 class BlenderbotForConditionalGeneration(BartForConditionalGeneration):
     """
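A usage sketch for the new bare class (the tiny config values below are invented so the example runs without downloading weights): `BlenderbotModel` returns raw hidden states, with no language-modeling head applied.

```python
import torch
from transformers import BlenderbotConfig, BlenderbotModel

config = BlenderbotConfig(  # hypothetical tiny config, not a released checkpoint
    vocab_size=100, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BlenderbotModel(config)
input_ids = torch.tensor([[4, 5, 6, 2]])  # arbitrary token ids
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 4, 16]) -- decoder states, no LM head
```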
@@ -27,7 +27,7 @@ if is_tokenizers_available():
     from .tokenization_mbart_fast import MBartTokenizerFast
 
 if is_torch_available():
-    from .modeling_mbart import MBartForConditionalGeneration
+    from .modeling_mbart import MBartForConditionalGeneration, MBartModel
 
 if is_tf_available():
     from .modeling_tf_mbart import TFMBartForConditionalGeneration
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..bart.modeling_bart import BartForConditionalGeneration
+from ..bart.modeling_bart import BartForConditionalGeneration, BartModel
 from .configuration_mbart import MBartConfig
 
 
@@ -26,6 +26,23 @@ MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
+class MBartModel(BartModel):
+    r"""
+    This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
+    documentation alongside usage examples.
+    """
+
+    config_class = MBartConfig
+    _keys_to_ignore_on_load_missing = [
+        "encoder.embed_positions.weight",
+        "decoder.embed_positions.weight",
+    ]
+    _keys_to_ignore_on_save = [
+        "encoder.embed_positions.weight",
+        "decoder.embed_positions.weight",
+    ]
+
+
 class MBartForConditionalGeneration(BartForConditionalGeneration):
     r"""
     This class overrides :class:`~transformers.BartForConditionalGeneration`. Please check the superclass for the
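The two `_keys_to_ignore_*` attributes hook into `PreTrainedModel`: keys matching `_keys_to_ignore_on_save` are dropped from the state dict written by `save_pretrained`, and keys matching `_keys_to_ignore_on_load_missing` are filtered out of the missing-keys warning in `from_pretrained`. A hedged round-trip sketch (tiny config values invented for the example):

```python
import tempfile

from transformers import MBartConfig, MBartModel

config = MBartConfig(  # hypothetical tiny config
    vocab_size=100, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = MBartModel(config)
with tempfile.TemporaryDirectory() as tmp:
    model.save_pretrained(tmp)                  # embed_positions weights are omitted on save
    reloaded = MBartModel.from_pretrained(tmp)  # ...and re-initialized on load, with no missing-key warning
```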
@@ -27,7 +27,7 @@ if is_tokenizers_available():
     from .tokenization_pegasus_fast import PegasusTokenizerFast
 
 if is_torch_available():
-    from .modeling_pegasus import PegasusForConditionalGeneration
+    from .modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
 
 if is_tf_available():
     from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
@@ -16,10 +16,34 @@
 
 
 from ...file_utils import add_start_docstrings
-from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
+from ..bart.modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration, BartModel
 from .configuration_pegasus import PegasusConfig
 
 
+@add_start_docstrings(
+    "The bare Pegasus Model transformer outputting raw hidden-states without any specific head on top.",
+    BART_START_DOCSTRING,
+)
+class PegasusModel(BartModel):
+    r"""
+    This class overrides :class:`~transformers.BartModel`. Please check the superclass for the appropriate
+    documentation alongside usage examples.
+    """
+
+    config_class = PegasusConfig
+    _keys_to_ignore_on_load_missing = [
+        r"final_logits_bias",
+        r"encoder\.version",
+        r"decoder\.version",
+        "encoder.embed_positions",
+        "decoder.embed_positions",
+    ]
+    _keys_to_ignore_on_save = [
+        "encoder.embed_positions.weight",
+        "decoder.embed_positions.weight",
+    ]
+
+
 @add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
 class PegasusForConditionalGeneration(BartForConditionalGeneration):
     r"""
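Note that the Pegasus ignore list mixes regex entries (`r"encoder\.version"`) with plain strings; as I read `PreTrainedModel` at this point, each entry is applied as a search pattern over the missing keys. A small illustration of that filtering:

```python
import re

# Patterns copied from PegasusModel above; the key names are made-up examples.
patterns = [r"final_logits_bias", r"encoder\.version", r"decoder\.version",
            "encoder.embed_positions", "decoder.embed_positions"]
missing_keys = ["encoder.embed_positions.weight", "encoder.layers.0.fc1.weight"]

reported = [k for k in missing_keys if not any(re.search(p, k) for p in patterns)]
print(reported)  # ['encoder.layers.0.fc1.weight'] -- the position embedding key is suppressed
```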
@@ -600,6 +600,15 @@ class BlenderbotForConditionalGeneration:
         requires_pytorch(self)
 
 
+class BlenderbotModel:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
@@ -1297,6 +1306,15 @@ class MBartForConditionalGeneration:
         requires_pytorch(self)
 
 
+class MBartModel:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class MMBTForClassification:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
@@ -1560,6 +1578,15 @@ class PegasusForConditionalGeneration:
         requires_pytorch(self)
 
 
+class PegasusModel:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
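These dummy definitions keep `from transformers import PegasusModel` (and friends) importable in an environment without PyTorch; the failure is deferred to first use. A sketch of the resulting behavior, assuming torch is absent:

```python
# Without torch installed, the top-level import still succeeds because the
# dummy class is exported in place of the real model...
from transformers import PegasusModel

try:
    PegasusModel()  # ...but any use goes through requires_pytorch
except ImportError as err:
    print(err)  # message explains that PyTorch is required
```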
@@ -32,6 +32,7 @@ if is_torch_available():
         AutoTokenizer,
         BlenderbotConfig,
         BlenderbotForConditionalGeneration,
+        BlenderbotModel,
         BlenderbotSmallTokenizer,
         BlenderbotTokenizer,
     )
@@ -90,7 +91,7 @@ class BlenderbotModelTester:
 class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
     if is_torch_available():
         all_generative_model_classes = (BlenderbotForConditionalGeneration,)
-        all_model_classes = (BlenderbotForConditionalGeneration,)
+        all_model_classes = (BlenderbotForConditionalGeneration, BlenderbotModel)
     else:
         all_generative_model_classes = ()
         all_model_classes = ()
@@ -30,6 +30,7 @@ if is_torch_available():
         BatchEncoding,
         MBartConfig,
         MBartForConditionalGeneration,
+        MBartModel,
     )
 
 
@@ -59,7 +60,7 @@ class ModelTester:
 
 @require_torch
 class SelectiveCommonTest(unittest.TestCase):
-    all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
+    all_model_classes = (MBartForConditionalGeneration, MBartModel) if is_torch_available() else ()
 
     test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
 
@@ -26,7 +26,7 @@ from .test_modeling_mbart import AbstractSeq2SeqIntegrationTest
 
 
 if is_torch_available():
-    from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration
+    from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel
 
 XSUM_ENTRY_LONGER = """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """
 
@@ -55,7 +55,7 @@ class ModelTester:
 
 @require_torch
 class SelectiveCommonTest(unittest.TestCase):
-    all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
+    all_model_classes = (PegasusForConditionalGeneration, PegasusModel) if is_torch_available() else ()
 
     test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
 