Update the GPT-J docstring code samples and register the modeling file in
utils/documentation_tests.txt.

In modeling_gptj.py, _CHECKPOINT_FOR_DOC now points at
hf-internal-testing/tiny-random-gptj, and GPTJForSequenceClassification and
GPTJForQuestionAnswering each get a dedicated checkpoint constant plus
expected output / expected loss constants that are passed to
add_code_sample_docstrings.

diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 18e3aa6303f..d045603038f 100755
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -36,10 +36,19 @@ from .configuration_gptj import GPTJConfig
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B"
+_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj"
 _CONFIG_FOR_DOC = "GPTJConfig"
 _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
 
+_CHECKPOINT_FOR_QA = "ydshieh/tiny-random-gptj-for-question-answering"
+_QA_EXPECTED_OUTPUT = "' was Jim Henson?Jim Henson was a n'"
+_QA_EXPECTED_LOSS = 3.13
+
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ydshieh/tiny-random-gptj-for-sequence-classification"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.76
+
+
 GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "EleutherAI/gpt-j-6B",
     # See all GPT-J models at https://huggingface.co/models?filter=gptj
@@ -892,9 +901,11 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
     @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=SequenceClassifierOutputWithPast,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1017,9 +1028,11 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
     @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_QA,
         output_type=QuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_QA_EXPECTED_OUTPUT,
+        expected_loss=_QA_EXPECTED_LOSS,
     )
     def forward(
         self,
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index 170076244cd..8e96b81e6d7 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -19,6 +19,7 @@ src/transformers/models/deit/modeling_deit.py
 src/transformers/models/dpt/modeling_dpt.py
 src/transformers/models/glpn/modeling_glpn.py
 src/transformers/models/gpt2/modeling_gpt2.py
+src/transformers/models/gptj/modeling_gptj.py
 src/transformers/models/hubert/modeling_hubert.py
 src/transformers/models/marian/modeling_marian.py
 src/transformers/models/mbart/modeling_mbart.py