mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Add Doc Test for GPT-J (#16507)
* Required the values GPTJ unfortunately cannot run the model =) * Added the file to the doc tests * Run Fixup and Style * Fixed with the test versions of gptj. Ran Style and Fixup. * Trigger ci * A Minor Change to License * Fixed spacing added to the benchmark_utils. Then refactored tests to const variables. * Removed strings that were included as default parameters anyways. Co-authored-by: ArEnSc <xx.mike.chung.xx@gmail.com>
This commit is contained in:
parent
12bfa97a43
commit
06b4aac9eb
@ -36,10 +36,19 @@ from .configuration_gptj import GPTJConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6B"
|
||||
_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj"
|
||||
_CONFIG_FOR_DOC = "GPTJConfig"
|
||||
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
||||
|
||||
_CHECKPOINT_FOR_QA = "ydshieh/tiny-random-gptj-for-question-answering"
|
||||
_QA_EXPECTED_OUTPUT = "' was Jim Henson?Jim Henson was a n'"
|
||||
_QA_EXPECTED_LOSS = 3.13
|
||||
|
||||
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ydshieh/tiny-random-gptj-for-sequence-classification"
|
||||
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
|
||||
_SEQ_CLASS_EXPECTED_LOSS = 0.76
|
||||
|
||||
|
||||
GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"EleutherAI/gpt-j-6B",
|
||||
# See all GPT-J models at https://huggingface.co/models?filter=gptj
|
||||
@ -892,9 +901,11 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
|
||||
output_type=SequenceClassifierOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
|
||||
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
@ -1017,9 +1028,11 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_QA,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output=_QA_EXPECTED_OUTPUT,
|
||||
expected_loss=_QA_EXPECTED_LOSS,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
|
@ -19,6 +19,7 @@ src/transformers/models/deit/modeling_deit.py
|
||||
src/transformers/models/dpt/modeling_dpt.py
|
||||
src/transformers/models/glpn/modeling_glpn.py
|
||||
src/transformers/models/gpt2/modeling_gpt2.py
|
||||
src/transformers/models/gptj/modeling_gptj.py
|
||||
src/transformers/models/hubert/modeling_hubert.py
|
||||
src/transformers/models/marian/modeling_marian.py
|
||||
src/transformers/models/mbart/modeling_mbart.py
|
||||
|
Loading…
Reference in New Issue
Block a user