mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix OPTForQuestionAnswering
doctest (#19479)
* Fix doc example for OPTForQuestionAnswering Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
957ce6465a
commit
fa9e18c65f
@ -53,12 +53,6 @@ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
|
||||
_SEQ_CLASS_EXPECTED_LOSS = 1.71
|
||||
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
|
||||
|
||||
# QuestionAnswering docstring
|
||||
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
|
||||
_QA_EXPECTED_LOSS = 7.41
|
||||
_QA_TARGET_START_INDEX = 14
|
||||
_QA_TARGET_END_INDEX = 15
|
||||
|
||||
OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-350m",
|
||||
@ -1140,16 +1134,7 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
qa_target_start_index=_QA_TARGET_START_INDEX,
|
||||
qa_target_end_index=_QA_TARGET_END_INDEX,
|
||||
expected_output=_QA_EXPECTED_OUTPUT,
|
||||
expected_loss=_QA_EXPECTED_LOSS,
|
||||
)
|
||||
@replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
@ -1173,7 +1158,36 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
||||
are not taken into account for computing the loss.
|
||||
"""
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import GPT2Tokenizer, OPTForQuestionAnswering
|
||||
>>> import torch
|
||||
|
||||
>>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
|
||||
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
>>> # note: we are loading a OPTForQuestionAnswering from the hub here,
|
||||
>>> # so the head will be randomly initialized, hence the predictions will be random
|
||||
>>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
|
||||
|
||||
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||
|
||||
>>> inputs = tokenizer(question, text, return_tensors="pt")
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> answer_start_index = outputs.start_logits.argmax()
|
||||
>>> answer_end_index = outputs.end_logits.argmax()
|
||||
|
||||
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
|
||||
>>> predicted = tokenizer.decode(predict_answer_tokens)
|
||||
>>> predicted
|
||||
' Henson?'
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.model(
|
||||
|
Loading…
Reference in New Issue
Block a user