Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
Adds LlamaForQuestionAnswering class in modeling_llama.py along with AutoModel Support (#28777)
* This is a test commit
* testing commit
* final commit with some changes
* Removed copy statement
* Fixed formatting issues
* Fixed error: added past_key_values in the forward method
* Fixed a trailing whitespace. Damn, the formatting rules are strict
* Added the copy statement
This commit is contained in:
parent ac51e59e47
commit 2e7c942c81
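For orientation before the diff: a minimal sketch of how the new head is used once this commit is installed. The checkpoint id is a placeholder for any Llama checkpoint; note that on a plain language-model checkpoint the new `qa_outputs` layer is freshly initialized, so the predicted span is meaningless until the model is fine-tuned for extractive QA.

```python
import torch
from transformers import AutoTokenizer, LlamaForQuestionAnswering

# Placeholder checkpoint id -- substitute any Llama checkpoint you have access to.
checkpoint = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LlamaForQuestionAnswering.from_pretrained(checkpoint)

question = "What is added by this commit?"
context = "This commit adds a LlamaForQuestionAnswering class with a span classification head."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# One start and one end logit per input token; argmax gives a greedy answer span.
start = outputs.start_logits.argmax(dim=-1).item()
end = outputs.end_logits.argmax(dim=-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```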
docs/source/en/model_doc/llama.md
@@ -116,6 +116,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
 [[autodoc]] LlamaForSequenceClassification
     - forward
 
+## LlamaForQuestionAnswering
+
+[[autodoc]] LlamaForQuestionAnswering
+    - forward
+
 ## FlaxLlamaModel
 
 [[autodoc]] FlaxLlamaModel
docs/source/en/tasks/question_answering.md
@@ -36,7 +36,7 @@ The task illustrated in this tutorial is supported by the following model archit
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
 
 
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [LXMERT](../model_doc/lxmert), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OPT](../model_doc/opt), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Splinter](../model_doc/splinter), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
 
 
 <!--End of the generated tip-->
src/transformers/__init__.py
@@ -2483,6 +2483,7 @@ else:
     _import_structure["models.llama"].extend(
         [
             "LlamaForCausalLM",
+            "LlamaForQuestionAnswering",
             "LlamaForSequenceClassification",
             "LlamaModel",
             "LlamaPreTrainedModel",
@@ -7025,7 +7026,13 @@ if TYPE_CHECKING:
             LiltModel,
             LiltPreTrainedModel,
         )
-        from .models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
+        from .models.llama import (
+            LlamaForCausalLM,
+            LlamaForQuestionAnswering,
+            LlamaForSequenceClassification,
+            LlamaModel,
+            LlamaPreTrainedModel,
+        )
         from .models.llava import (
             LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST,
             LlavaForConditionalGeneration,
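Both `__init__.py` hunks above follow the library's lazy-import scheme: names are registered as strings in `_import_structure` and only materialized on first attribute access. A simplified sketch of the pattern, not the exact library code (the real file also guards on optional backends such as torch):

```python
# Simplified sketch of the lazy-import pattern used by transformers' __init__.py.
import sys

from transformers.utils import _LazyModule

_import_structure = {
    "models.llama": [
        "LlamaForCausalLM",
        "LlamaForQuestionAnswering",  # the name this commit registers
        "LlamaForSequenceClassification",
        "LlamaModel",
        "LlamaPreTrainedModel",
    ],
}

# Swapping in a _LazyModule defers the heavy torch imports until a name is accessed.
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```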
src/transformers/models/auto/modeling_auto.py
@@ -849,6 +849,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
         ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
         ("led", "LEDForQuestionAnswering"),
         ("lilt", "LiltForQuestionAnswering"),
+        ("llama", "LlamaForQuestionAnswering"),
         ("longformer", "LongformerForQuestionAnswering"),
         ("luke", "LukeForQuestionAnswering"),
         ("lxmert", "LxmertForQuestionAnswering"),
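The `("llama", "LlamaForQuestionAnswering")` entry is what gives the commit its AutoModel support: `AutoModelForQuestionAnswering` looks up the config's `model_type` in this mapping. A quick sketch (the checkpoint id is again a placeholder):

```python
from transformers import AutoConfig, AutoModelForQuestionAnswering

# model_type == "llama" in the config now resolves through the mapping above.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder id
model = AutoModelForQuestionAnswering.from_config(config)
print(type(model).__name__)  # -> LlamaForQuestionAnswering
```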
src/transformers/models/llama/__init__.py
@@ -54,6 +54,7 @@ else:
         "LlamaModel",
         "LlamaPreTrainedModel",
         "LlamaForSequenceClassification",
+        "LlamaForQuestionAnswering",
     ]
 
     try:
@@ -90,7 +91,13 @@ if TYPE_CHECKING:
     except OptionalDependencyNotAvailable:
         pass
     else:
-        from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
+        from .modeling_llama import (
+            LlamaForCausalLM,
+            LlamaForQuestionAnswering,
+            LlamaForSequenceClassification,
+            LlamaModel,
+            LlamaPreTrainedModel,
+        )
 
     try:
         if not is_flax_available():
src/transformers/models/llama/modeling_llama.py
@@ -36,7 +36,12 @@ from ...modeling_attn_mask_utils import (
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
 )
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+)
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
 from ...utils import (
@@ -1413,3 +1418,100 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
+SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    LLAMA_START_DOCSTRING,
+)
+class LlamaForQuestionAnswering(LlamaPreTrainedModel):
+    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = LlamaModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.transformer.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.transformer.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1).to(start_logits.device)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1).to(end_logits.device)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
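The forward pass above returns per-token start and end logits; turning them into an answer string is a small amount of post-processing. A hedged sketch of the greedy version (a production pipeline instead scores top-k start/end pairs and restricts candidates to the context tokens):

```python
import torch

def extract_answer(tokenizer, model, question: str, context: str) -> str:
    """Greedy span extraction on top of LlamaForQuestionAnswering outputs."""
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Independent argmax over the start and end logits.
    start = int(outputs.start_logits.argmax(dim=-1))
    end = int(outputs.end_logits.argmax(dim=-1))
    if end < start:
        return ""  # degenerate span; real pipelines search valid (start, end) pairs
    answer_ids = inputs["input_ids"][0, start : end + 1]
    return tokenizer.decode(answer_ids, skip_special_tokens=True)
```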
src/transformers/utils/dummy_pt_objects.py
@@ -4689,6 +4689,13 @@ class LlamaForCausalLM(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class LlamaForQuestionAnswering(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class LlamaForSequenceClassification(metaclass=DummyObject):
     _backends = ["torch"]
 
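The dummy entry mirrors the real class so that importing `LlamaForQuestionAnswering` still succeeds in a torch-less environment, failing only at instantiation with a message naming the missing backend. Roughly what that buys you (sketch; assumes torch is absent):

```python
# In an environment without torch installed, the import itself still works...
from transformers import LlamaForQuestionAnswering

try:
    model = LlamaForQuestionAnswering()  # ...but instantiation raises
except ImportError as err:
    print(err)  # message explains that the "torch" backend is required
```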
tests/models/llama/test_modeling_llama.py
@@ -44,6 +44,7 @@ if is_torch_available():
     from transformers import (
         CodeLlamaTokenizer,
         LlamaForCausalLM,
+        LlamaForQuestionAnswering,
         LlamaForSequenceClassification,
         LlamaModel,
         LlamaTokenizer,
@@ -278,7 +279,11 @@ class LlamaModelTester:
 
 @require_torch
 class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering)
+        if is_torch_available()
+        else ()
+    )
     all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
@@ -286,6 +291,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             "text-classification": LlamaForSequenceClassification,
             "text-generation": LlamaForCausalLM,
             "zero-shot": LlamaForSequenceClassification,
+            "question-answering": LlamaForQuestionAnswering,
         }
         if is_torch_available()
         else {}
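Registering the class under `"question-answering"` in the pipeline mapping also makes it reachable from the high-level API. A sketch, with a hypothetical fine-tuned checkpoint id:

```python
from transformers import pipeline

# Hypothetical model id; substitute a Llama checkpoint fine-tuned for extractive QA.
qa = pipeline("question-answering", model="my-org/llama-7b-squad")
result = qa(
    question="What does the qa_outputs layer compute?",
    context="The qa_outputs layer is a linear projection that produces start and end logits.",
)
print(result["answer"], result["score"])
```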