OPT - fix docstring and improve tests slightly (#17228)

* correct some stuff

* fix doc tests

* make style
Patrick von Platen 2022-05-13 15:14:50 +02:00 committed by GitHub
parent dfc76018c1
commit 18d6b356c5
3 changed files with 62 additions and 53 deletions

src/transformers/models/opt/modeling_opt.py

@@ -37,12 +37,12 @@ from .configuration_opt import OPTConfig

 logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = ""
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
 _CONFIG_FOR_DOC = "OPTConfig"
 _TOKENIZER_FOR_DOC = "GPT2Tokenizer"

 # Base model docstring
-_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]

 OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
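These module-level constants feed the `add_code_sample_docstrings` decorator, which generates the runnable example embedded in `OPTModel.forward`'s docstring; the checkpoint name and expected output shape therefore have to point at a real model for the doc tests enabled later in this commit to pass. A minimal sketch of that wiring, assuming the decorator keywords used in transformers at the time (`OPTModelStub` is a hypothetical stand-in, not the real class):

```python
# Sketch of how the doc constants feed the docstring generator; the
# decorator keywords mirror the convention used in modeling_opt.py at
# the time, and OPTModelStub is a hypothetical stand-in for OPTModel.
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import add_code_sample_docstrings

_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
_CONFIG_FOR_DOC = "OPTConfig"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]


class OPTModelStub:
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(self, input_ids=None, **kwargs):
        """Stand-in for OPTModel.forward; the decorator appends the sample."""
```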
@@ -424,25 +424,6 @@ class OPTPreTrainedModel(PreTrainedModel):
             module.gradient_checkpointing = value

-OPT_GENERATION_EXAMPLE = r"""
-    Generation example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
-
-    >>> model = OPTForCausalLM.from_pretrained("ArthurZ/opt-350m")
-    >>> tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
-
-    >>> TEXTS_TO_GENERATE = "Hey, are you consciours? Can you talk to me?" "Hi there, my name is Barack"
-    >>> inputs = tokenizer([TEXTS_TO_GENERATE], max_length=1024, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'I'm not conscious.<\s>'
-    ```
-"""

 OPT_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -933,19 +914,18 @@ class OPTForCausalLM(OPTPreTrainedModel):
         Example:

         ```python
-        >>> from transformers import OPTTokenizer, OPTForCausalLM
+        >>> from transformers import GPT2Tokenizer, OPTForCausalLM

-        # this needs fixing
-        >>> tokenizer = OPTTokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
-        >>> model = OPTForCausalLM.from_pretrained("ArthurZ/opt-350m")
-        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
-        >>> outputs = model(**inputs)
-
-        >>> logits = outputs.logits
-        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
-        >>> list(logits.shape) == expected_shape
-        True
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you consciours? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
         ```"""

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
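Outside the docstring, the same generation runs with the Auto classes; a hedged convenience sketch, not part of the diff (`use_fast=False` keeps the slow GPT-2 BPE tokenizer, matching the `GPT2Tokenizer` used in the docstring example):

```python
# Standalone reproduction of the new docstring example via the Auto classes
# (a sketch, not part of the diff; use_fast=False selects the slow GPT-2
# tokenizer to mirror GPT2Tokenizer above).
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=False)

inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")

# Greedy decoding capped at 30 tokens total, as in the docstring example.
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```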

tests/models/opt/test_modeling_opt.py

@@ -21,7 +21,7 @@ import unittest

 import timeout_decorator  # noqa

-from transformers import OPTConfig, is_torch_available, pipeline
+from transformers import OPTConfig, is_torch_available
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
 from transformers.utils import cached_property
@@ -330,33 +330,61 @@ class OPTEmbeddingsTest(unittest.TestCase):
         assert torch.allclose(logits, logits_meta, atol=1e-4)


-@require_tokenizers
 @slow
 class OPTGenerationTest(unittest.TestCase):
-    def setUp(self):
-        super().setUp()
-        self.all_model_path = ["facebook/opt-125m", "facebook/opt-350m"]
-
-    def test_generation(self):
-        prompts = [
+    @property
+    def prompts(self):
+        return [
             "Today is a beautiful day and I want to",
             "In the city of",
             "Paris is the capital of France and",
             "Computers and mobile phones have taken",
         ]
-        NEXT_TOKENS = [3392, 764, 5, 81]
-        GEN_OUTPUT = []
-
-        tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
-        for model in self.all_model_path:
-            model = OPTForCausalLM.from_pretrained(self.path_model)
-            model = model.eval()
-            model.config.eos_token_id = tokenizer.eos_token_id
-
-            gen = pipeline("text-generation", model=model, tokenizer=tokenizer, return_tensors=True)
-
-            for prompt in prompts:
-                len_input_sentence = len(tokenizer.tokenize(prompt))
-                predicted_next_token = gen(prompt)[0]["generated_token_ids"][len_input_sentence]
-                GEN_OUTPUT.append(predicted_next_token)
-        self.assertListEqual(GEN_OUTPUT, NEXT_TOKENS)
+
+    def test_generation_pre_attn_layer_norm(self):
+        model_id = "facebook/opt-125m"
+
+        EXPECTED_OUTPUTS = [
+            "Today is a beautiful day and I want to thank",
+            "In the city of Rome Canaver Canaver Canaver Canaver",
+            "Paris is the capital of France and Parisdylib",
+            "Computers and mobile phones have taken precedence over",
+        ]
+
+        predicted_outputs = []
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+        model = OPTForCausalLM.from_pretrained(model_id)
+
+        for prompt in self.prompts:
+            input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+            generated_ids = model.generate(input_ids, max_length=10)
+
+            generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+            predicted_outputs += generated_string
+
+        self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+    def test_generation_post_attn_layer_norm(self):
+        model_id = "facebook/opt-350m"
+
+        EXPECTED_OUTPUTS = [
+            "Today is a beautiful day and I want to share",
+            "In the city of San Francisco, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
+        ]
+
+        predicted_outputs = []
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+        model = OPTForCausalLM.from_pretrained(model_id)
+
+        for prompt in self.prompts:
+            input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+            generated_ids = model.generate(input_ids, max_length=10)
+
+            generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+            predicted_outputs += generated_string
+
+        self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
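Both new tests carry `@slow`, so they are skipped unless `RUN_SLOW=1` is set, the standard transformers opt-in for slow tests. A sketch of a local invocation (the test-file path is inferred from this commit and may differ across releases):

```python
# Run the new @slow OPT generation tests locally; RUN_SLOW=1 enables
# slow tests in transformers' test suite. The file path is inferred
# from this commit and may differ across releases.
import os
import subprocess

subprocess.run(
    [
        "python", "-m", "pytest",
        "tests/models/opt/test_modeling_opt.py",
        "-k", "OPTGenerationTest",
        "-v",
    ],
    check=True,
    env={**os.environ, "RUN_SLOW": "1"},
)
```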

utils/documentation_tests.txt

@@ -35,6 +35,7 @@ src/transformers/models/marian/modeling_marian.py
 src/transformers/models/mbart/modeling_mbart.py
 src/transformers/models/mobilebert/modeling_mobilebert.py
 src/transformers/models/mobilebert/modeling_tf_mobilebert.py
+src/transformers/models/opt/modeling_opt.py
 src/transformers/models/pegasus/modeling_pegasus.py
 src/transformers/models/plbart/modeling_plbart.py
 src/transformers/models/poolformer/modeling_poolformer.py
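Listing `modeling_opt.py` here opts its docstrings into the documentation test suite, which is what makes the fixed example above actually execute. A sketch of the equivalent local run, assuming the doc-test job consumes this file with pytest's doctest collection (exact CI flags may differ):

```python
# Run the OPT docstring examples the way the doc-test job consumes
# utils/documentation_tests.txt; the exact CI flags are an assumption.
import subprocess

subprocess.run(
    [
        "python", "-m", "pytest",
        "--doctest-modules",
        "src/transformers/models/opt/modeling_opt.py",
        "-sv",
        "--doctest-continue-on-failure",
    ],
    check=True,
)
```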