Fix model integration CI (#26322)

* fix wav2vec2

* nit

* stash

* one more file to update

* fix byt5

* vocab size is 256, don't change that!

* use other revision

* test persimmon in smaller size

* style

* tests

* nits

* update add tokens from pretrained

* test tokenization

* nits

* potential fnet fix?

* more nits

* nits

* correct test

* assert close

* update

* ouch

* fix it

* some more nits

* FINALLY

* use `adept` checkpoints

* more adept checkpoints

* that was involved!
Arthur 2023-10-02 13:55:46 +02:00 committed by GitHub
parent 6824461f2a
commit 63864e057f
9 changed files with 38 additions and 18 deletions

src/transformers/models/byt5/tokenization_byt5.py

@@ -104,7 +104,7 @@ class ByT5Tokenizer(PreTrainedTokenizer):
         return self._utf_vocab_size
 
     def get_vocab(self):
-        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
         vocab.update(self.added_tokens_encoder)
         return vocab

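Note on the ByT5 change above: the byte vocabulary is fixed at 256 ids (hence the "vocab size is 256, don't change that!" commit message), but those ids are shifted up by `offset` to make room for the special tokens, so iterating only `range(self.vocab_size)` left the last `offset` byte ids out of `get_vocab`. A minimal sketch of the resulting behaviour, assuming the public `google/byt5-small` checkpoint:

```python
from transformers import AutoTokenizer

# ByT5 maps byte `b` to id `b + offset`, where the first `offset` ids are
# reserved for special tokens; `vocab_size` itself stays at 256.
tok = AutoTokenizer.from_pretrained("google/byt5-small")
print(tok.vocab_size)      # 256
vocab = tok.get_vocab()    # with the fix, covers the full shifted byte range plus added tokens
print(len(vocab) > tok.vocab_size)
```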
src/transformers/models/persimmon/configuration_persimmon.py

@@ -21,7 +21,7 @@ from ...utils import logging
 logger = logging.get_logger(__name__)
 
 PERSIMMON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "ArthurZ/persimmon-8b-base": "https://huggingface.co/ArthurZ/persimmon-8b-base/resolve/main/config.json",
+    "adept/persimmon-8b-base": "https://huggingface.co/adept/persimmon-8b-base/resolve/main/config.json",
 }
@@ -30,7 +30,7 @@ class PersimmonConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`PersimmonModel`]. It is used to instantiate an
     Persimmon model according to the specified arguments, defining the model architecture. Instantiating a
     configuration with the defaults will yield a similar configuration to that of the
-    [ArthurZ/persimmon-8b-base](https://huggingface.co/ArthurZ/persimmon-8b-base).
+    [adept/persimmon-8b-base](https://huggingface.co/adept/persimmon-8b-base).
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.

src/transformers/models/persimmon/modeling_persimmon.py

@@ -789,8 +789,8 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
         ```python
         >>> from transformers import AutoTokenizer, PersimmonForCausalLM
 
-        >>> model = PersimmonForCausalLM.from_pretrained("ArthurZ/persimmon-8b-base")
-        >>> tokenizer = AutoTokenizer.from_pretrained("ArthurZ/persimmon-8b-base")
+        >>> model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base")
+        >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
 
         >>> prompt = "human: Hey, what should I eat for dinner?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")

src/transformers/models/wav2vec2/tokenization_wav2vec2.py

@@ -232,9 +232,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
 
         # make sure that tokens made of several
         # characters are not split at tokenization
-        for token in self.encoder.keys():
-            if len(token) > 1:
-                self.unique_no_split_tokens.append(token)
+        self.add_tokens([token for token in self.encoder.keys() if len(token) > 1])
 
     @property
     def word_delimiter_token(self) -> str:

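The Wav2Vec2 tokenizer change swaps direct writes to the old `unique_no_split_tokens` list (removed by the tokenizer refactor) for the public `add_tokens` API; the effect is the same: multi-character vocabulary entries are registered as atomic tokens that never get split. A rough equivalent applied from the outside, assuming the public `facebook/wav2vec2-base-960h` checkpoint:

```python
from transformers import AutoTokenizer

# Same idea as the diff above: register every multi-character vocabulary entry
# (e.g. "<pad>", "<unk>") so the tokenizer treats it as a single unsplittable symbol.
tok = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
tok.add_tokens([token for token in tok.get_vocab() if len(token) > 1])
```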
src/transformers/tokenization_utils_base.py

@@ -2209,7 +2209,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 " it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again."
                 " You will see the new `added_tokens_decoder` attribute that will store the relevant information."
             )
-
         # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified
         if special_tokens_map_file is not None:
             with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
@@ -2221,6 +2220,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                         continue
                     if isinstance(value, dict):
                         value = AddedToken(**value)
+                        init_kwargs[key] = value
                     elif key == "additional_special_tokens" and isinstance(value, list):
                         for token in value:
                             token = AddedToken(**token) if isinstance(token, dict) else token
@@ -2233,8 +2233,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
                 added_tok_encoder = json.load(added_tokens_handle)
 
             # legacy: we have to init with (rstrip=True, lstrip=True)
+            strip = True if "Fast" not in cls.__name__ else False
             added_tokens_decoder = {
-                index: AddedToken(token, rstrip=True, lstrip=True) for token, index in added_tok_encoder.items()
+                index: AddedToken(token, rstrip=strip, lstrip=strip) for token, index in added_tok_encoder.items()
             }
         # end legacy

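For context on the legacy branch above: when an old-style `added_tokens.json` is converted, slow tokenizer classes keep the historical `lstrip=rstrip=True` behaviour, while classes with `Fast` in the name now get `False`. A small sketch of that flag logic mirrored outside the loader (the class names are just examples):

```python
from transformers import AddedToken

# Mirrors the `strip = True if "Fast" not in cls.__name__ else False` line:
# slow tokenizers strip whitespace around legacy added tokens, fast ones do not.
for cls_name in ("BertTokenizer", "BertTokenizerFast"):
    strip = "Fast" not in cls_name
    token = AddedToken("<legacy_token>", lstrip=strip, rstrip=strip)
    print(cls_name, token.lstrip, token.rstrip)
```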
tests/models/fnet/test_modeling_fnet.py

@@ -532,8 +532,6 @@ class FNetModelIntegrationTest(unittest.TestCase):
     @slow
     @require_tokenizers
     def test_inference_long_sentence(self):
-        model = FNetForMaskedLM.from_pretrained("google/fnet-base")
-        model.to(torch_device)
         tokenizer = FNetTokenizerFast.from_pretrained("google/fnet-base")
 
         inputs = tokenizer(
@@ -543,8 +541,15 @@ class FNetModelIntegrationTest(unittest.TestCase):
             padding="max_length",
             max_length=512,
         )
+        # fmt: off
torch.testing.assert_allclose(inputs["input_ids"], torch.tensor([[4, 13, 283, 2479, 106, 8, 6, 845, 5, 168, 65, 367, 6, 845, 5, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3, 3, 3, 3, 3, 3, 3, 3, 3, 3,3]]))
+        # fmt: on
+
         inputs = {k: v.to(torch_device) for k, v in inputs.items()}
 
+        model = FNetForMaskedLM.from_pretrained("google/fnet-base")
+        model.to(torch_device)
+
         logits = model(**inputs).logits
         predictions_mask_1 = tokenizer.decode(logits[0, 6].topk(5).indices)
         predictions_mask_2 = tokenizer.decode(logits[0, 12].topk(5).indices)

tests/models/idefics/test_modeling_idefics.py

@@ -503,7 +503,11 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
 class IdeficsModelIntegrationTest(TestCasePlus):
     @cached_property
     def default_processor(self):
-        return IdeficsProcessor.from_pretrained("HuggingFaceM4/idefics-9b") if is_vision_available() else None
+        return (
+            IdeficsProcessor.from_pretrained("HuggingFaceM4/idefics-9b", revision="refs/pr/11")
+            if is_vision_available()
+            else None
+        )
 
     @require_bitsandbytes
     @slow

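The Idefics fixture now pins the processor to a Hub pull-request ref instead of `main`. `from_pretrained` forwards `revision` to the Hub download, and it accepts a branch, tag, commit hash, or a `refs/pr/<n>` ref, which is what lets the CI test files from an unmerged Hub PR. A minimal sketch of the same pattern:

```python
from transformers import AutoProcessor

# Load from an explicit Hub revision (branch, tag, commit sha, or PR ref)
# rather than the default "main" branch.
processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics-9b",
    revision="refs/pr/11",
)
```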
tests/models/instructblip/test_modeling_instructblip.py

@@ -29,7 +29,14 @@ from transformers import (
     InstructBlipQFormerConfig,
     InstructBlipVisionConfig,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_accelerate,
+    require_bitsandbytes,
+    require_torch,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_torch_available, is_vision_available
 
 from ...test_configuration_common import ConfigTester
@@ -522,6 +529,7 @@ def prepare_img():
 @slow
 class InstructBlipModelIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
+    @require_accelerate
     def test_inference_vicuna_7b(self):
         processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
         model = InstructBlipForConditionalGeneration.from_pretrained(

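The added `@require_accelerate` marker reflects that the 8-bit path exercised by this test needs `accelerate` on top of `bitsandbytes`. Roughly, the kind of call it guards looks like the sketch below (an assumption about the test body, not a copy of it); `load_in_8bit=True` fails with a clear error if either library is missing:

```python
from transformers import InstructBlipForConditionalGeneration

# 8-bit weights come from bitsandbytes; placing the quantized modules on the
# right devices goes through accelerate, hence both CI requirements.
model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-vicuna-7b",
    load_in_8bit=True,
    device_map="auto",
)
```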
tests/models/persimmon/test_modeling_persimmon.py

@@ -386,11 +386,13 @@ class PersimmonIntegrationTest(unittest.TestCase):
     @slow
     def test_model_8b_chat_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
-        model = PersimmonForCausalLM.from_pretrained("ArthurZ/persimmon-8b-chat", device_map="auto")
+        model = PersimmonForCausalLM.from_pretrained(
+            "adept/persimmon-8b-chat", device_map="auto", torch_dtype=torch.float16
+        )
         out = model(torch.tensor([input_ids])).logits
 
         EXPECTED_MEAN = torch.tensor(
-            [[-11.2879, -11.2628, -11.2498, -11.2534, -11.2676, -11.2638, -11.2501, -11.2431]], dtype=torch.float32
+            [[-11.2879, -11.2628, -11.2498, -11.2534, -11.2676, -11.2638, -11.2501, -11.2431]], dtype=torch.float16
         )
         torch.testing.assert_close(out.cpu().mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
         # fmt: off
@@ -403,9 +405,11 @@ class PersimmonIntegrationTest(unittest.TestCase):
     def test_model_8b_chat_greedy_generation(self):
         EXPECTED_TEXT_COMPLETION = """human: Simply put, the theory of relativity states that?\n\nadept: The theory of relativity states that the laws of physics are the same for all observers, regardless of their relative motion."""
         prompt = "human: Simply put, the theory of relativity states that?\n\nadept:"
-        tokenizer = AutoTokenizer.from_pretrained("ArthurZ/persimmon-8b-chat", use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-chat", use_fast=False)
         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device)
-        model = PersimmonForCausalLM.from_pretrained("ArthurZ/persimmon-8b-chat").to(torch_device)
+        model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-chat", torch_dtype=torch.float16).to(
+            torch_device
+        )
 
         # greedy generation outputs
         generated_ids = model.generate(input_ids, max_new_tokens=64)
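Beyond moving to the `adept` checkpoints, both Persimmon tests now load the 8B model in `torch.float16` so it fits on the CI GPUs, and the expected statistics are stored in the matching dtype. The comparison itself uses `torch.testing.assert_close` with explicit tolerances; a minimal self-contained sketch of that pattern (the numbers below are made up):

```python
import torch

# Compare a reduced statistic of the logits against a stored reference,
# with explicit absolute/relative tolerances, as the integration test does.
observed = torch.tensor([[-11.2880, -11.2627, -11.2499]], dtype=torch.float16)
expected = torch.tensor([[-11.2879, -11.2628, -11.2498]], dtype=torch.float16)
torch.testing.assert_close(observed, expected, atol=1e-4, rtol=1e-4)
```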