From 0d1cff119578b04f52249670f0fc58abadfc832d Mon Sep 17 00:00:00 2001 From: Minh Chien Vu <31467068+vumichien@users.noreply.github.com> Date: Sat, 23 Apr 2022 01:07:16 +0900 Subject: [PATCH] Add doc tests for Albert and Bigbird (#16774) * Add doctest BERT * make fixup * fix typo * change checkpoints * make fixup * define doctest output value, update doctest for mobilebert * solve fix-copies * update QA target start index and end index * change checkpoint for docs and reuse defined variable * Update src/transformers/models/bert/modeling_tf_bert.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * make fixup * Add Doctest for Albert and Bigbird * make fixup * overwrite examples for Albert and Bigbird * Apply suggestions from code review Co-authored-by: Patrick von Platen * update longer examples for Bigbird * using examples from squad_v2 * print out example text * change name token-classification-big-bird checkpoint to random Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Co-authored-by: Patrick von Platen --- .../models/albert/modeling_albert.py | 58 ++++-- .../models/albert/modeling_tf_albert.py | 57 ++++-- .../models/big_bird/modeling_big_bird.py | 175 ++++++++++++++---- utils/documentation_tests.txt | 3 + 4 files changed, 230 insertions(+), 63 deletions(-) diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 241e3550811..ff68e467882 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -801,9 +801,8 @@ class AlbertForPreTraining(AlbertPreTrainedModel): >>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2") >>> model = AlbertForPreTraining.from_pretrained("albert-base-v2") - >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze( - ... 0 - >>> ) # Batch size 1 + >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) + >>> # Batch size 1 >>> outputs = model(input_ids) >>> prediction_logits = outputs.prediction_logits @@ -914,12 +913,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel): return self.albert.embeddings.word_embeddings @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids=None, @@ -938,6 +932,37 @@ class AlbertForMaskedLM(AlbertPreTrainedModel): Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AlbertTokenizer, AlbertForMaskedLM + + >>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2") + >>> model = AlbertForMaskedLM.from_pretrained("albert-base-v2") + + >>> # add mask_token + >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt") + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> # retrieve index of [MASK] + >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0] + >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1) + >>> tokenizer.decode(predicted_token_id) + 'france' + ``` + + ```python + >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] + >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) + >>> outputs = model(**inputs, labels=labels) + >>> round(outputs.loss.item(), 2) + 0.81 + ``` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -996,9 +1021,11 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel): @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, + checkpoint="textattack/albert-base-v2-imdb", output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, + expected_output="'LABEL_1'", + expected_loss=0.12, ) def forward( self, @@ -1103,9 +1130,12 @@ class AlbertForTokenClassification(AlbertPreTrainedModel): @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, + checkpoint="vumichien/tiny-albert", output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC, + expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_1', " + "'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_0', 'LABEL_1', 'LABEL_1']", + expected_loss=0.66, ) def forward( self, @@ -1184,9 +1214,13 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel): @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, + checkpoint="twmkn9/albert-base-v2-squad2", output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC, + qa_target_start_index=12, + qa_target_end_index=13, + expected_output="'a nice puppet'", + expected_loss=7.36, ) def forward( self, diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py index ae325558cd7..93af948e210 100644 --- a/src/transformers/models/albert/modeling_tf_albert.py +++ b/src/transformers/models/albert/modeling_tf_albert.py @@ -865,9 +865,8 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss): >>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2") >>> model = TFAlbertForPreTraining.from_pretrained("albert-base-v2") - >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[ - ... 
None, : - >>> ] # Batch size 1 + >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] + >>> # Batch size 1 >>> outputs = model(input_ids) >>> prediction_logits = outputs.prediction_logits @@ -954,12 +953,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss) @unpack_inputs @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TFMaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) + @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) def call( self, input_ids: Optional[TFModelInputType] = None, @@ -979,6 +973,36 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss) Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + Returns: + + Example: + + ```python + >>> import tensorflow as tf + >>> from transformers import AlbertTokenizer, TFAlbertForMaskedLM + + >>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2") + >>> model = TFAlbertForMaskedLM.from_pretrained("albert-base-v2") + + >>> # add mask_token + >>> inputs = tokenizer(f"The capital of [MASK] is Paris.", return_tensors="tf") + >>> logits = model(**inputs).logits + + >>> # retrieve index of [MASK] + >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1] + >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1) + >>> tokenizer.decode(predicted_token_id) + 'france' + ``` + + ```python + >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"] + >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) + >>> outputs = model(**inputs, labels=labels) + >>> round(float(outputs.loss), 2) + 0.81 + ``` """ outputs = self.albert( input_ids=input_ids, @@ -1043,9 +1067,11 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, + checkpoint="vumichien/albert-base-v2-imdb", output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, + expected_output="'LABEL_1'", + expected_loss=0.12, ) def call( self, @@ -1136,9 +1162,12 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, + checkpoint="vumichien/tiny-albert", output_type=TFTokenClassifierOutput, config_class=_CONFIG_FOR_DOC, + expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_1', 'LABEL_1', " + "'LABEL_0', 'LABEL_1', 'LABEL_0', 'LABEL_0', 'LABEL_1', 'LABEL_1']", + expected_loss=0.66, ) def call( self, @@ -1220,9 +1249,13 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( 
processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, + checkpoint="vumichien/albert-base-v2-squad2", output_type=TFQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC, + qa_target_start_index=12, + qa_target_end_index=13, + expected_output="'a nice puppet'", + expected_loss=7.36, ) def call( self, diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f255a363297..563e8631902 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -2392,12 +2392,7 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel): self.cls.predictions.decoder = new_embeddings @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=MaskedLMOutput, - config_class=_CONFIG_FOR_DOC, - ) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -2418,6 +2413,49 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel): Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import BigBirdTokenizer, BigBirdForMaskedLM + >>> from datasets import load_dataset + + >>> tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base") + >>> model = BigBirdForMaskedLM.from_pretrained("google/bigbird-roberta-base") + >>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT + + >>> # select random long article + >>> LONG_ARTICLE_TARGET = squad_ds[81514]["context"] + >>> # select random sentence + >>> LONG_ARTICLE_TARGET[332:398] + 'the highest values are very close to the theoretical maximum value' + + >>> # add mask_token + >>> LONG_ARTICLE_TO_MASK = LONG_ARTICLE_TARGET.replace("maximum", "[MASK]") + >>> inputs = tokenizer(LONG_ARTICLE_TO_MASK, return_tensors="pt") + >>> # long article input + >>> list(inputs["input_ids"].shape) + [1, 919] + + >>> with torch.no_grad(): + ... 
logits = model(**inputs).logits + >>> # retrieve index of [MASK] + >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0] + >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1) + >>> tokenizer.decode(predicted_token_id) + 'maximum' + ``` + + ```python + >>> labels = tokenizer(LONG_ARTICLE_TARGET, return_tensors="pt")["input_ids"] + >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) + >>> outputs = model(**inputs, labels=labels) + >>> round(outputs.loss.item(), 2) + 1.08 + ``` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -2496,7 +2534,12 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel): self.cls.predictions.decoder = new_embeddings @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids: torch.LongTensor = None, @@ -2536,25 +2579,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - - Returns: - - Example: - - ```python - >>> from transformers import BigBirdTokenizer, BigBirdForCausalLM, BigBirdConfig - >>> import torch - - >>> tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base") - >>> config = BigBirdConfig.from_pretrained("google/bigbird-roberta-base") - >>> config.is_decoder = True - >>> model = BigBirdForCausalLM.from_pretrained("google/bigbird-roberta-base", config=config) - - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - - >>> prediction_logits = outputs.logits - ```""" + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bert( @@ -2662,12 +2687,7 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel): self.post_init() @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=SequenceClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -2686,6 +2706,43 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import BigBirdTokenizer, BigBirdForSequenceClassification + >>> from datasets import load_dataset + + >>> tokenizer = BigBirdTokenizer.from_pretrained("l-yohai/bigbird-roberta-base-mnli") + >>> model = BigBirdForSequenceClassification.from_pretrained("l-yohai/bigbird-roberta-base-mnli") + >>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT + + >>> LONG_ARTICLE = squad_ds[81514]["context"] + >>> inputs = tokenizer(LONG_ARTICLE, return_tensors="pt") + >>> # long input article + >>> list(inputs["input_ids"].shape) + [1, 919] + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + >>> predicted_class_id = logits.argmax().item() + >>> model.config.id2label[predicted_class_id] + 'LABEL_0' + ``` + + ```python + >>> num_labels = len(model.config.id2label) + >>> model = BigBirdForSequenceClassification.from_pretrained( + ... "l-yohai/bigbird-roberta-base-mnli", num_labels=num_labels + ... ) + >>> labels = torch.tensor(1) + >>> loss = model(**inputs, labels=labels).loss + >>> round(loss.item(), 2) + 1.13 + ``` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -2858,9 +2915,12 @@ class BigBirdForTokenClassification(BigBirdPreTrainedModel): @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, - checkpoint=_CHECKPOINT_FOR_DOC, + checkpoint="vumichien/token-classification-bigbird-roberta-base-random", output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC, + expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', " + "'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1']", + expected_loss=0.54, ) def forward( self, @@ -2955,12 +3015,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): self.post_init() @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - processor_class=_TOKENIZER_FOR_DOC, - checkpoint="google/bigbird-base-trivia-itc", - output_type=BigBirdForQuestionAnsweringModelOutput, - config_class=_CONFIG_FOR_DOC, - ) + @replace_return_docstrings(output_type=BigBirdForQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -2985,6 +3040,48 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): Labels for position (index) of the end of the labelled span for computing the token classification loss. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import BigBirdTokenizer, BigBirdForQuestionAnswering + >>> from datasets import load_dataset + + >>> tokenizer = BigBirdTokenizer.from_pretrained("abhinavkulkarni/bigbird-roberta-base-finetuned-squad") + >>> model = BigBirdForQuestionAnswering.from_pretrained("abhinavkulkarni/bigbird-roberta-base-finetuned-squad") + >>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT + + >>> # select random article and question + >>> LONG_ARTICLE = squad_ds[81514]["context"] + >>> QUESTION = squad_ds[81514]["question"] + >>> QUESTION + 'During daytime how high can the temperatures reach?' 
+ + >>> inputs = tokenizer(QUESTION, LONG_ARTICLE, return_tensors="pt") + >>> # long article and question input + >>> list(inputs["input_ids"].shape) + [1, 929] + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1] + >>> tokenizer.decode(predict_answer_tokens) + '80 °C (176 °F) or more' + ``` + + ```python + >>> target_start_index, target_end_index = torch.tensor([130]), torch.tensor([132]) + >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) + >>> loss = outputs.loss + >>> round(outputs.loss.item(), 2) + 7.63 + ``` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index a2ee77c0488..e420642a94f 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -6,11 +6,14 @@ docs/source/en/model_doc/t5v1.1.mdx docs/source/en/model_doc/byt5.mdx docs/source/en/model_doc/tapex.mdx src/transformers/generation_utils.py +src/transformers/models/albert/modeling_albert.py +src/transformers/models/albert/modeling_tf_albert.py src/transformers/models/bart/modeling_bart.py src/transformers/models/beit/modeling_beit.py src/transformers/models/bert/modeling_bert.py src/transformers/models/bert/modeling_tf_bert.py src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +src/transformers/models/big_bird/modeling_big_bird.py src/transformers/models/blenderbot/modeling_blenderbot.py src/transformers/models/blenderbot_small/modeling_blenderbot_small.py src/transformers/models/convnext/modeling_convnext.py