Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Add doc tests for Albert and Bigbird (#16774)
* Add doctest BERT
* make fixup
* fix typo
* change checkpoints
* make fixup
* define doctest output value, update doctest for mobilebert
* solve fix-copies
* update QA target start index and end index
* change checkpoint for docs and reuse defined variable
* Update src/transformers/models/bert/modeling_tf_bert.py
  Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
* make fixup
* Add Doctest for Albert and Bigbird
* make fixup
* overwrite examples for Albert and Bigbird
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* update longer examples for Bigbird
* using examples from squad_v2
* print out example text
* change name token-classification-big-bird checkpoint to random

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
9fa88172c2
commit
0d1cff1195
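The change converts the docstring snippets into runnable doctests with pinned outputs, either written inline in the docstring (masked LM) or generated by `add_code_sample_docstrings` from decorator arguments (classification and QA heads). Condensed into a plain script, the new PyTorch Albert masked-LM example from the hunks below looks roughly like this (a sketch; it assumes network access to download the `albert-base-v2` checkpoint):

```python
# Standalone version of the masked-LM doctest added for AlbertForMaskedLM below.
# Assumes the `albert-base-v2` checkpoint can be downloaded.
import torch
from transformers import AlbertTokenizer, AlbertForMaskedLM

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForMaskedLM.from_pretrained("albert-base-v2")

inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Index of the [MASK] token and the model's top prediction for it.
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
print(tokenizer.decode(predicted_token_id))  # the doctest in the diff pins this to 'france'

# Loss against the unmasked sentence, with non-masked positions ignored via -100.
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
outputs = model(**inputs, labels=labels)
print(round(outputs.loss.item(), 2))  # pinned to 0.81 in the diff
```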
src/transformers/models/albert/modeling_albert.py
@@ -801,9 +801,8 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
>>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
>>> model = AlbertForPreTraining.from_pretrained("albert-base-v2")

- >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
- ... 0
- >>> ) # Batch size 1
+ >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
+ >>> # Batch size 1
>>> outputs = model(input_ids)

>>> prediction_logits = outputs.prediction_logits
@@ -914,12 +913,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
return self.albert.embeddings.word_embeddings

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @add_code_sample_docstrings(
- processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=MaskedLMOutput,
- config_class=_CONFIG_FOR_DOC,
- )
+ @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@@ -938,6 +932,37 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AlbertTokenizer, AlbertForMaskedLM
+
+ >>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
+ >>> model = AlbertForMaskedLM.from_pretrained("albert-base-v2")
+
+ >>> # add mask_token
+ >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> # retrieve index of [MASK]
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
+ >>> tokenizer.decode(predicted_token_id)
+ 'france'
+ ```
+
+ ```python
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
+ >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
+ >>> outputs = model(**inputs, labels=labels)
+ >>> round(outputs.loss.item(), 2)
+ 0.81
+ ```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -996,9 +1021,11 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="textattack/albert-base-v2-imdb",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="'LABEL_1'",
+ expected_loss=0.12,
)
def forward(
self,
@@ -1103,9 +1130,12 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="vumichien/tiny-albert",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_1', "
+ "'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_0', 'LABEL_1', 'LABEL_1']",
+ expected_loss=0.66,
)
def forward(
self,
@@ -1184,9 +1214,13 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="twmkn9/albert-base-v2-squad2",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
+ qa_target_start_index=12,
+ qa_target_end_index=13,
+ expected_output="'a nice puppet'",
+ expected_loss=7.36,
)
def forward(
self,
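For the question-answering head, the example text itself is rendered by `add_code_sample_docstrings` from the decorator arguments above rather than written out in the diff. Roughly, and assuming the default question/context pair used by the doc-sample templates at the time (not part of this diff), the generated example behaves like the sketch below:

```python
# Approximate expansion of the QA code sample for the checkpoint pinned above.
# The question/context pair is an assumed template default, not taken from this diff.
import torch
from transformers import AlbertTokenizer, AlbertForQuestionAnswering

tokenizer = AlbertTokenizer.from_pretrained("twmkn9/albert-base-v2-squad2")
model = AlbertForQuestionAnswering.from_pretrained("twmkn9/albert-base-v2-squad2")

question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Decode the span between the most likely start and end positions.
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
print(tokenizer.decode(predict_answer_tokens))  # expected_output above pins 'a nice puppet'

# Loss against the target span given by qa_target_start_index / qa_target_end_index.
target_start_index, target_end_index = torch.tensor([12]), torch.tensor([13])
outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
print(round(outputs.loss.item(), 2))  # expected_loss above pins 7.36
```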
src/transformers/models/albert/modeling_tf_albert.py
@@ -865,9 +865,8 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
>>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
>>> model = TFAlbertForPreTraining.from_pretrained("albert-base-v2")

- >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[
- ... None, :
- >>> ] # Batch size 1
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
+ >>> # Batch size 1
>>> outputs = model(input_ids)

>>> prediction_logits = outputs.prediction_logits
@@ -954,12 +953,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)

@unpack_inputs
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @add_code_sample_docstrings(
- processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=TFMaskedLMOutput,
- config_class=_CONFIG_FOR_DOC,
- )
+ @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids: Optional[TFModelInputType] = None,
@@ -979,6 +973,36 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import tensorflow as tf
+ >>> from transformers import AlbertTokenizer, TFAlbertForMaskedLM
+
+ >>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
+ >>> model = TFAlbertForMaskedLM.from_pretrained("albert-base-v2")
+
+ >>> # add mask_token
+ >>> inputs = tokenizer(f"The capital of [MASK] is Paris.", return_tensors="tf")
+ >>> logits = model(**inputs).logits
+
+ >>> # retrieve index of [MASK]
+ >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
+ >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
+ >>> tokenizer.decode(predicted_token_id)
+ 'france'
+ ```
+
+ ```python
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
+ >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
+ >>> outputs = model(**inputs, labels=labels)
+ >>> round(float(outputs.loss), 2)
+ 0.81
+ ```
"""
outputs = self.albert(
input_ids=input_ids,
@@ -1043,9 +1067,11 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="vumichien/albert-base-v2-imdb",
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="'LABEL_1'",
+ expected_loss=0.12,
)
def call(
self,
@@ -1136,9 +1162,12 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="vumichien/tiny-albert",
output_type=TFTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_1', "
+ "'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_0', 'LABEL_1', 'LABEL_1']",
+ expected_loss=0.66,
)
def call(
self,
@@ -1220,9 +1249,13 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="vumichien/albert-base-v2-squad2",
output_type=TFQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
+ qa_target_start_index=12,
+ qa_target_end_index=13,
+ expected_output="'a nice puppet'",
+ expected_loss=7.36,
)
def call(
self,
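The TF decorators follow the same pattern; for instance, the sequence-classification sample pinned to `vumichien/albert-base-v2-imdb` above expands to something along these lines (a sketch with an assumed input sentence; the exact template lives in `add_code_sample_docstrings`):

```python
# Approximate expansion of the TF sequence-classification code sample for the
# checkpoint pinned above; the input sentence is an assumption, not taken from this diff.
import tensorflow as tf
from transformers import AlbertTokenizer, TFAlbertForSequenceClassification

tokenizer = AlbertTokenizer.from_pretrained("vumichien/albert-base-v2-imdb")
model = TFAlbertForSequenceClassification.from_pretrained("vumichien/albert-base-v2-imdb")

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
logits = model(**inputs).logits

# Map the highest-scoring class index back to its label name.
predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_class_id])  # expected_output above pins 'LABEL_1'
```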
src/transformers/models/big_bird/modeling_big_bird.py
@@ -2392,12 +2392,7 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
self.cls.predictions.decoder = new_embeddings

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @add_code_sample_docstrings(
- processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=MaskedLMOutput,
- config_class=_CONFIG_FOR_DOC,
- )
+ @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -2418,6 +2413,49 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import BigBirdTokenizer, BigBirdForMaskedLM
+ >>> from datasets import load_dataset
+
+ >>> tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
+ >>> model = BigBirdForMaskedLM.from_pretrained("google/bigbird-roberta-base")
+ >>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT
+
+ >>> # select random long article
+ >>> LONG_ARTICLE_TARGET = squad_ds[81514]["context"]
+ >>> # select random sentence
+ >>> LONG_ARTICLE_TARGET[332:398]
+ 'the highest values are very close to the theoretical maximum value'
+
+ >>> # add mask_token
+ >>> LONG_ARTICLE_TO_MASK = LONG_ARTICLE_TARGET.replace("maximum", "[MASK]")
+ >>> inputs = tokenizer(LONG_ARTICLE_TO_MASK, return_tensors="pt")
+ >>> # long article input
+ >>> list(inputs["input_ids"].shape)
+ [1, 919]
+
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+ >>> # retrieve index of [MASK]
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
+ >>> tokenizer.decode(predicted_token_id)
+ 'maximum'
+ ```
+
+ ```python
+ >>> labels = tokenizer(LONG_ARTICLE_TARGET, return_tensors="pt")["input_ids"]
+ >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
+ >>> outputs = model(**inputs, labels=labels)
+ >>> round(outputs.loss.item(), 2)
+ 1.08
+ ```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -2496,7 +2534,12 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
self.cls.predictions.decoder = new_embeddings

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutputWithCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -2536,25 +2579,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import BigBirdTokenizer, BigBirdForCausalLM, BigBirdConfig
- >>> import torch
-
- >>> tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
- >>> config = BigBirdConfig.from_pretrained("google/bigbird-roberta-base")
- >>> config.is_decoder = True
- >>> model = BigBirdForCausalLM.from_pretrained("google/bigbird-roberta-base", config=config)
-
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
- >>> outputs = model(**inputs)
-
- >>> prediction_logits = outputs.logits
- ```"""
+ """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
@@ -2662,12 +2687,7 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
self.post_init()

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @add_code_sample_docstrings(
- processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=SequenceClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -2686,6 +2706,43 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import BigBirdTokenizer, BigBirdForSequenceClassification
+ >>> from datasets import load_dataset
+
+ >>> tokenizer = BigBirdTokenizer.from_pretrained("l-yohai/bigbird-roberta-base-mnli")
+ >>> model = BigBirdForSequenceClassification.from_pretrained("l-yohai/bigbird-roberta-base-mnli")
+ >>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT
+
+ >>> LONG_ARTICLE = squad_ds[81514]["context"]
+ >>> inputs = tokenizer(LONG_ARTICLE, return_tensors="pt")
+ >>> # long input article
+ >>> list(inputs["input_ids"].shape)
+ [1, 919]
+
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+ >>> predicted_class_id = logits.argmax().item()
+ >>> model.config.id2label[predicted_class_id]
+ 'LABEL_0'
+ ```
+
+ ```python
+ >>> num_labels = len(model.config.id2label)
+ >>> model = BigBirdForSequenceClassification.from_pretrained(
+ ... "l-yohai/bigbird-roberta-base-mnli", num_labels=num_labels
+ ... )
+ >>> labels = torch.tensor(1)
+ >>> loss = model(**inputs, labels=labels).loss
+ >>> round(loss.item(), 2)
+ 1.13
+ ```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -2858,9 +2915,12 @@ class BigBirdForTokenClassification(BigBirdPreTrainedModel):
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
- checkpoint=_CHECKPOINT_FOR_DOC,
+ checkpoint="vumichien/token-classification-bigbird-roberta-base-random",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
+ expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', "
+ "'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1']",
+ expected_loss=0.54,
)
def forward(
self,
@@ -2955,12 +3015,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
self.post_init()

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @add_code_sample_docstrings(
- processor_class=_TOKENIZER_FOR_DOC,
- checkpoint="google/bigbird-base-trivia-itc",
- output_type=BigBirdForQuestionAnsweringModelOutput,
- config_class=_CONFIG_FOR_DOC,
- )
+ @replace_return_docstrings(output_type=BigBirdForQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -2985,6 +3040,48 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
are not taken into account for computing the loss.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import BigBirdTokenizer, BigBirdForQuestionAnswering
+ >>> from datasets import load_dataset
+
+ >>> tokenizer = BigBirdTokenizer.from_pretrained("abhinavkulkarni/bigbird-roberta-base-finetuned-squad")
+ >>> model = BigBirdForQuestionAnswering.from_pretrained("abhinavkulkarni/bigbird-roberta-base-finetuned-squad")
+ >>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT
+
+ >>> # select random article and question
+ >>> LONG_ARTICLE = squad_ds[81514]["context"]
+ >>> QUESTION = squad_ds[81514]["question"]
+ >>> QUESTION
+ 'During daytime how high can the temperatures reach?'
+
+ >>> inputs = tokenizer(QUESTION, LONG_ARTICLE, return_tensors="pt")
+ >>> # long article and question input
+ >>> list(inputs["input_ids"].shape)
+ [1, 929]
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> answer_start_index = outputs.start_logits.argmax()
+ >>> answer_end_index = outputs.end_logits.argmax()
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
+ >>> tokenizer.decode(predict_answer_tokens)
+ '80 °C (176 °F) or more'
+ ```
+
+ ```python
+ >>> target_start_index, target_end_index = torch.tensor([130]), torch.tensor([132])
+ >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
+ >>> loss = outputs.loss
+ >>> round(outputs.loss.item(), 2)
+ 7.63
+ ```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
utils/documentation_tests.txt
@@ -6,11 +6,14 @@ docs/source/en/model_doc/t5v1.1.mdx
docs/source/en/model_doc/byt5.mdx
docs/source/en/model_doc/tapex.mdx
src/transformers/generation_utils.py
+ src/transformers/models/albert/modeling_albert.py
+ src/transformers/models/albert/modeling_tf_albert.py
src/transformers/models/bart/modeling_bart.py
src/transformers/models/beit/modeling_beit.py
src/transformers/models/bert/modeling_bert.py
src/transformers/models/bert/modeling_tf_bert.py
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+ src/transformers/models/big_bird/modeling_big_bird.py
src/transformers/models/blenderbot/modeling_blenderbot.py
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
src/transformers/models/convnext/modeling_convnext.py
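The `utils/documentation_tests.txt` hunk above registers the three touched modules with the scheduled doc-test run. A quick, hypothetical sanity check from a repository checkout (file path assumed unchanged) that the entries are present:

```python
# Hypothetical check: the modules this commit adds doctests for should now be listed
# in utils/documentation_tests.txt (run from the repository root).
from pathlib import Path

listed = set(Path("utils/documentation_tests.txt").read_text().splitlines())
for path in [
    "src/transformers/models/albert/modeling_albert.py",
    "src/transformers/models/albert/modeling_tf_albert.py",
    "src/transformers/models/big_bird/modeling_big_bird.py",
]:
    assert path in listed, f"{path} missing from utils/documentation_tests.txt"
```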