diff --git a/docs/source/en/model_doc/biogpt.mdx b/docs/source/en/model_doc/biogpt.mdx
index b334a05ede6..37b23402c26 100644
--- a/docs/source/en/model_doc/biogpt.mdx
+++ b/docs/source/en/model_doc/biogpt.mdx
@@ -54,8 +54,15 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
 
 [[autodoc]] BioGptForCausalLM
     - forward
 
+
 ## BioGptForTokenClassification
 
 [[autodoc]] BioGptForTokenClassification
+    - forward
+
+
+## BioGptForSequenceClassification
+
+[[autodoc]] BioGptForSequenceClassification
     - forward
\ No newline at end of file
diff --git a/docs/source/en/tasks/sequence_classification.mdx b/docs/source/en/tasks/sequence_classification.mdx
index 6c062dcf934..fa15b5be30b 100644
--- a/docs/source/en/tasks/sequence_classification.mdx
+++ b/docs/source/en/tasks/sequence_classification.mdx
@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
 
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
 
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
 
 <!--End of the generated tip-->
 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 7da443b6705..47ca7f62a05 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1147,6 +1147,7 @@ else:
         [
             "BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
             "BioGptForCausalLM",
+            "BioGptForSequenceClassification",
             "BioGptForTokenClassification",
             "BioGptModel",
             "BioGptPreTrainedModel",
@@ -4792,6 +4793,7 @@ if TYPE_CHECKING:
     from .models.biogpt import (
         BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BioGptForCausalLM,
+        BioGptForSequenceClassification,
         BioGptForTokenClassification,
         BioGptModel,
         BioGptPreTrainedModel,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index e6bc37d7371..15f7e01759f 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -648,6 +648,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("bert", "BertForSequenceClassification"),
         ("big_bird", "BigBirdForSequenceClassification"),
         ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
+        ("biogpt", "BioGptForSequenceClassification"),
         ("bloom", "BloomForSequenceClassification"),
         ("camembert", "CamembertForSequenceClassification"),
         ("canine", "CanineForSequenceClassification"),
diff --git a/src/transformers/models/biogpt/__init__.py b/src/transformers/models/biogpt/__init__.py
index 761b904013c..ec3d6966ac4 100644
--- a/src/transformers/models/biogpt/__init__.py
+++ b/src/transformers/models/biogpt/__init__.py
@@ -31,6 +31,7 @@ else:
         "BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST",
         "BioGptForCausalLM",
         "BioGptForTokenClassification",
+        "BioGptForSequenceClassification",
         "BioGptModel",
         "BioGptPreTrainedModel",
     ]
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
     from .modeling_biogpt import (
         BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BioGptForCausalLM,
+        BioGptForSequenceClassification,
         BioGptForTokenClassification,
         BioGptModel,
         BioGptPreTrainedModel,
diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index 59b7f1f4ac7..4fc7f7b4893 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -22,16 +22,22 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
+    SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)
 from .configuration_biogpt import BioGptConfig
 
@@ -40,8 +46,10 @@ logger = logging.get_logger(__name__)
 
 _CHECKPOINT_FOR_DOC = "microsoft/biogpt"
 _CONFIG_FOR_DOC = "BioGptConfig"
 
+
 BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "microsoft/biogpt",
+    "microsoft/BioGPT-Large",
     # See all BioGPT models at https://huggingface.co/models?filter=biogpt
 ]
@@ -832,3 +840,129 @@ class BioGptForTokenClassification(BioGptPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    The BioGpt Model transformer with a sequence classification head on top (linear layer).
+
+    [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it is required to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    BIOGPT_START_DOCSTRING,
+)
+class BioGptForSequenceClassification(BioGptPreTrainedModel):
+    def __init__(self, config: BioGptConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.biogpt = BioGptModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(BIOGPT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.biogpt(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        if self.config.pad_token_id is None:
+            sequence_length = -1
+        else:
+            if input_ids is not None:
+                sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_length = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    def get_input_embeddings(self):
+        return self.biogpt.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.biogpt.embed_tokens = value
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 4b862f08bf8..05267a83dc7 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -1103,6 +1103,13 @@ class BioGptForCausalLM(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class BioGptForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class BioGptForTokenClassification(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py
index 4ad2edfc13f..7e64fc07dad 100644
--- a/tests/models/biogpt/test_modeling_biogpt.py
+++ b/tests/models/biogpt/test_modeling_biogpt.py
@@ -29,7 +29,13 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 
 
 if is_torch_available():
     import torch
 
-    from transformers import BioGptForCausalLM, BioGptForTokenClassification, BioGptModel, BioGptTokenizer
+    from transformers import (
+        BioGptForCausalLM,
+        BioGptForSequenceClassification,
+        BioGptForTokenClassification,
+        BioGptModel,
+        BioGptTokenizer,
+    )
     from transformers.models.biogpt.modeling_biogpt import BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -274,13 +280,18 @@ class BioGptModelTester:
 
 @require_torch
 class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (BioGptModel, BioGptForCausalLM, BioGptForTokenClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (BioGptModel, BioGptForCausalLM, BioGptForSequenceClassification, BioGptForTokenClassification)
+        if is_torch_available()
+        else ()
+    )
     all_generative_model_classes = (BioGptForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
            "feature-extraction": BioGptModel,
            "text-generation": BioGptForCausalLM,
            "token-classification": BioGptForTokenClassification,
+            "text-classification": BioGptForSequenceClassification,
         }
         if is_torch_available()
         else {}
@@ -374,6 +385,35 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
             model = BioGptModel.from_pretrained(model_name)
             self.assertIsNotNone(model)
 
+    # Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common
+    def test_biogpt_sequence_classification_model(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
+        model = BioGptForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
+    # Copied from tests.models.opt.test_modeling_opt.OPTModelTest with OPT->BioGpt, prepare_config_and_inputs-> prepare_config_and_inputs_for_common
+    def test_biogpt_sequence_classification_model_for_multi_label(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.num_labels = 3
+        config.problem_type = "multi_label_classification"
+        input_ids = input_dict["input_ids"]
+        attention_mask = input_ids.ne(1).to(torch_device)
+        sequence_labels = ids_tensor(
+            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
+        ).to(torch.float)
+        model = BioGptForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
+        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
+
 
 @require_torch
 class BioGptModelIntegrationTest(unittest.TestCase):
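
For reviewers who want to try the change, here is a minimal usage sketch (not part of the diff). It assumes the `microsoft/biogpt` checkpoint from `BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST` and an illustrative `num_labels=2`; the new classification head is freshly initialized on top of the pretrained backbone, so predictions are meaningless until fine-tuning. The point is that the `("biogpt", "BioGptForSequenceClassification")` entry added to `modeling_auto.py` lets `AutoModelForSequenceClassification` dispatch to the new class.

```python
# Sketch only: "microsoft/biogpt" ships a causal-LM checkpoint, so the sequence
# classification head loaded here starts from random weights and needs fine-tuning.
import torch

from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
model = AutoModelForSequenceClassification.from_pretrained("microsoft/biogpt", num_labels=2)  # num_labels is illustrative

inputs = tokenizer("Metformin is commonly prescribed for type 2 diabetes.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape (1, num_labels), pooled from the last non-padding token

predicted_class_id = logits.argmax(dim=-1).item()
print(model.config.id2label[predicted_class_id])  # e.g. "LABEL_0" / "LABEL_1" before fine-tuning
```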
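The class docstring above states that classification happens on the last non-padding token when `config.pad_token_id` is set. A standalone sketch of that indexing rule, mirroring the lines in `forward`; the tensor values and the pad id of 1 (BioGptConfig's default) are made up for illustration:

```python
import torch

batch_size, seq_len, num_labels, pad_token_id = 2, 6, 3, 1
input_ids = torch.tensor(
    [
        [5, 8, 9, 1, 1, 1],  # padded row: last real token at index 2
        [5, 8, 9, 4, 7, 2],  # full row: last token at index 5
    ]
)
logits = torch.randn(batch_size, seq_len, num_labels)  # per-token scores from the linear head

# Index of the last non-padding token in each row, as in BioGptForSequenceClassification.forward.
sequence_length = torch.ne(input_ids, pad_token_id).sum(-1) - 1
pooled_logits = logits[torch.arange(batch_size), sequence_length]

print(sequence_length.tolist())  # [2, 5]
print(pooled_logits.shape)       # torch.Size([2, 3])
```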
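Finally, the two new tests exercise the label formats handled by the `problem_type` branch in `forward`. A small sketch of how the loss function is chosen (all tensors are dummy values): integer class ids go through `CrossEntropyLoss`, float multi-hot labels through `BCEWithLogitsLoss`, and `num_labels == 1` falls back to `MSELoss` regression.

```python
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

batch_size, num_labels = 4, 3
pooled_logits = torch.randn(batch_size, num_labels)

# Single-label classification: integer class ids, as in test_biogpt_sequence_classification_model.
single_labels = torch.tensor([0, 2, 1, 2])
ce_loss = CrossEntropyLoss()(pooled_logits.view(-1, num_labels), single_labels.view(-1))

# Multi-label classification: float multi-hot labels, as in the *_for_multi_label test.
multi_labels = torch.randint(0, 2, (batch_size, num_labels)).float()
bce_loss = BCEWithLogitsLoss()(pooled_logits, multi_labels)

# Regression: num_labels == 1 pairs a squeezed logit with a float target.
reg_logits = torch.randn(batch_size, 1)
mse_loss = MSELoss()(reg_logits.squeeze(), torch.randn(batch_size))

print(ce_loss.item(), bce_loss.item(), mse_loss.item())
```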