mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00

* add sdpa to OPT
* chore: remove redundant whitespace in OPTDecoder class
* fixup
* bug fix
* add sdpa and attention generate test
* fixup
* Refactor OPTAttention forward method for improved readability and maintainability
* undo refactor for _shape and key,val states
* add OPT to doc, fixup didn't find it for some reason
* change order
* change default attn_implementation in testing to eager
* [run-slow] opt
* change test_eager_matches_sdpa_generate to the one from llama
* Update default attention implementation in testing common
* [run-slow] opt
* remove unneeded print
* [run-slow] opt
* refactor model testers to have attn_implementation="eager"
* [run-slow] opt
* convert test_eager_matches_sdpa_generate to opt-350M
* bug fix when creating mask for opt
* [run-slow] opt
* if layer head mask is set, default to eager
* if head mask is not none fall to eager
* [run-slow] opt
* Update src/transformers/models/opt/modeling_opt.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Clean up Unpack imports (#33631)
clean up Unpack imports
* Fix DPT /Dinov2 sdpa regression on main (#33660)
* fall back to eager if output_attentions is requested.
* fix copies
* handle dependency errors in check_imports (#33622)
* handle dependency errors in check_imports
* change log level to warning
* add back self.max_position_embeddings = config.max_position_embeddings (#33550)
* add back self.max_position_embeddings = config.max_position_embeddings
* fix-copies
* Fix Llava conversion for LlavaQwen2ForCausalLM with Clip vision tower (#33613)
fix llavaqwen2 model conversion
* Uniformize kwargs for Udop processor and update docs (#33628)
* Add optional kwargs and uniformize udop
* cleanup Unpack
* nit Udop
* Generation: deprecate `PreTrainedModel` inheriting from `GenerationMixin` (#33203)
* Enable BNB multi-backend support (#31098)
* enable cpu bnb path
* fix style
* fix code style
* fix 4 bit path
* Update src/transformers/utils/import_utils.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* add multi backend refactor tests
* fix style
* tweak 4bit quantizer + fix corresponding tests
* tweak 8bit quantizer + *try* fixing corresponding tests
* fix dequant bnb 8bit
* account for Intel CPU in variability of expected outputs
* enable cpu and xpu device map
* further tweaks to account for Intel CPU
* fix autocast to work with both cpu + cuda
* fix comments
* fix comments
* switch to testing_utils.torch_device
* allow for xpu in multi-gpu tests
* fix tests 4bit for CPU NF4
* fix bug with is_torch_xpu_available needing to be called as func
* avoid issue where test reports attr err due to other failure
* fix formatting
* fix typo from resolving of merge conflict
* polish based on last PR review
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* fix CI
* Update src/transformers/integrations/integration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/integrations/integration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix error log
* fix error msg
* add \n in error log
* make quality
* rm bnb cuda restriction in doc
* CPU models don't need dispatch
* fix doc
* fix style
* check CUDA is available in testing
* fix tests
* Update docs/source/en/model_doc/chameleon.md
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Update docs/source/en/model_doc/llava_next.md
Co-authored-by: Aarni Koskela <akx@iki.fi>
* Update tests/quantization/bnb/test_4bit.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* Update tests/quantization/bnb/test_4bit.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* fix doc
* fix check multibackends
* fix import sort
* remove check torch in bnb
* docs: update bitsandbytes references with multi-backend info
* docs: fix small mistakes in bnb paragraph
* run formatting
* revert bnb check
* move bnb multi-backend check to import_utils
* Update src/transformers/utils/import_utils.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* fix bnb check
* minor fix for bnb
* check lib first
* fix code style
* Revert "run formatting"
This reverts commit ac108c6d6b.
* fix format
* give warning when bnb version is low and no CUDA is found
* fix device assignment check to be multi-device capable
* address akx feedback on get_avlbl_dev fn
* revert partially, as we don't want the function to be that public (public functions must be documented, which is enforced)
---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Fix error string after refactoring into get_chat_template (#33652)
* Fix error string after refactoring into get_chat_template
* Take suggestion from CR
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
---------
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* uniformize git processor (#33668)
* uniformize git processor
* update docstring
* Modular `transformers`: modularity and inheritance for new model additions (#33248)
* update example
* update
* push the converted diff files for testing and ci
* correct one example
* fix class attributes and docstring
* nits
* oups
* fixed config!
* update
* nits
* class attributes are not matched against the other, this is missing
* fixed overwriting self.xxx now onto the attributes I think
* partial fix, now order with docstring
* fix docstring order?
* more fixes
* update
* fix missing docstrings!
* examples don't all work yet
* fixup
* nit
* updated
* hick
* update
* delete
* update
* update
* update
* fix
* all default
* no local import
* fix more diff
* some fix related to "safe imports"
* push fixed
* add helper!
* style
* add a check
* all by default
* add the
* update
* FINALLY!
* nit
* fix config dependencies
* man that is it
* fix fix
* update diffs
* fix the last issue
* re-default to all
* all the fixes
* nice
* fix properties vs setter
* fixup
* updates
* update dependencies
* make sure to install what needs to be installed
* fixup
* quick fix for now
* fix!
* fixup
* update
* update
* updates
* whitespaces
* nit
* fix
* simplify everything, and make it file agnostic (should work for image processors)
* style
* finish fixing all import issues
* fixup
* empty modeling should not be written!
* Add logic to find who depends on what
* update
* cleanup
* update
* update gemma to support positions
* some small nits
* this is the correct docstring for gemma2
* fix merging of docstrings
* update
* fixup
* update
* take doc into account
* styling
* update
* fix hidden activation
* more fixes
* final fixes!
* fixup
* fixup instruct blip video
* update
* fix bugs
* align gemma2 with the rest as well
* updates
* revert
* update
* more reversion
* grind
* more
* arf
* update
* order will matter
* finish del stuff
* update
* rename to modular
* fixup
* nits
* update makefile
* fixup
* update order of the checks!
* fix
* fix docstring that has a call inside
* fix conversion check
* style
* add some initial documentation
* update
* update doc
* some fixup
* updates
* yups
* Mostly TODO, give me a minute
* update
* fixup
* revert some stuff
* Review docs for the modular transformers (#33472)
Docs
* good update
* fixup
* mmm current updates lead to this code
* okay, this fixes it
* cool
* fixes
* update
* nit
* updates
* nits
* fix doc
* update
* revert bad changes
* update
* updates
* proper update
* update
* update?
* up
* update
* cool
* nits
* nits
* bon bon
* fix
* ?
* minimise changes
* update
* update
* update
* updates?
* fixed gemma2
* kind of a hack
* nits
* update
* remove `diffs` in favor of `modular`
* fix make fix copies
---------
Co-authored-by: Lysandre Debut <hi@lysand.re>
* Fix CIs post merging modular transformers (#33681)
update
* Fixed docstring for cohere model regarding unavailability of prune_he… (#33253)
* Fixed docstring for cohere model regarding unavailability of prune_heads() methods
The docstring mentioned that the Cohere model supports the prune_heads() method. I have fixed the docstring to state explicitly that it does not support that functionality.
* Update src/transformers/models/cohere/modeling_cohere.py
---------
Co-authored-by: Lysandre Debut <hi@lysand.re>
* Generation tests: update imagegpt input name, remove unused functions (#33663)
* Improve Error Messaging for Flash Attention 2 on CPU (#33655)
Update flash-attn error message on CPU
Rebased to latest branch
* Gemma2: fix config initialization (`cache_implementation`) (#33684)
* Fix ByteLevel alphabet missing when Sequence pretokenizer is used (#33556)
* Fix ByteLevel alphabet missing when Sequence pretokenizer is used
* Fixed formatting with `ruff`.
* Uniformize kwargs for image-text-to-text processors (#32544)
* uniformize FUYU processor kwargs
* Uniformize instructblip processor kwargs
* Fix processor kwargs and tests Fuyu, InstructBlip, Kosmos2
* Uniformize llava_next processor
* Fix save_load test for processor with chat_template only as extra init args
* Fix import Unpack
* Fix Fuyu Processor import
* Fix FuyuProcessor import
* Fix FuyuProcessor
* Add defaults for specific kwargs kosmos2
* Fix Udop to return BatchFeature instead of BatchEncoding and uniformize kwargs
* Add tests processor Udop
* remove Copied from in processing Udop as change of input orders caused by BatchEncoding -> BatchFeature
* Fix overwrite tests kwargs processors
* Add warnings and BC for changes in processor inputs order, change docs, add BC for text_pair as arg for Udop
* Fix processing test fuyu
* remove unnecessary pad_token check in instructblip ProcessorTest
* Fix BC tests and cleanup
* Fix imports fuyu
* Uniformize Pix2Struct
* Fix wrong name for FuyuProcessorKwargs
* Fix slow tests reversed inputs align fuyu llava-next, change udop warning
* Fix wrong logging import udop
* Add check images text input order
* Fix copies
* change text pair handling when positional arg
* rebase on main, fix imports in test_processing_common
* remove optional args and udop uniformization from this PR
* fix failing tests
* remove unnecessary test, fix processing utils and test processing common
* cleanup Unpack
* cleanup
* fix conflict grounding dino
* 🚨🚨 Setting default behavior of assisted decoding (#33657)
* tests: fix pytorch tensor placement errors (#33485)
This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy"
According to the PyTorch documentation, torch.Tensor.numpy(force=False)
performs the conversion only if the tensor is on the CPU (plus a few other
restrictions), which is not the case here. We need force=True since we only
need the data and don't care about tensor coherency.
Fixes: #33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
* bump tokenizers, fix added tokens fast (#32535)
* update based on tokenizers release
* update
* nits
* update
* revert re addition
* don't break that yet
* fmt
* revert unwanted
* update tokenizers version
* update dep table
* update
* update in conversion script as well
* some fix
* revert
* fully revert
* fix training
* remove set trace
* fixup
* update
* update
* [Pixtral] Improve docs, rename model (#33491)
* Improve docs, rename model
* Fix style
* Update repo id
* fix code quality after merge
* HFQuantizer implementation for compressed-tensors library (#31704)
* Add compressed-tensors HFQuantizer implementation
* flag serializable as False
* run
* revive lines deleted by ruff
* fixes to load+save from sparseml, edit config to quantization_config, and load back
* address satrat comment
* compressed_tensors to compressed-tensors and revert back is_serializable
* rename quant_method from sparseml to compressed-tensors
* tests
* edit tests
* clean up tests
* make style
* cleanup
* cleanup
* add test skip for when compressed tensors is not installed
* remove pydantic import + style
* delay torch import in test
* initial docs
* update main init for compressed tensors config
* make fix-copies
* docstring
* remove fill_docstring
* Apply suggestions from code review
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* review comments
* review comments
* comments - suppress warnings on state dict load, tests, fixes
* bug-fix - remove unnecessary call to apply quant lifecycle
* run_compressed compatibility
* revert changes not needed for compression
* no longer need unexpected keys fn
* unexpected keys not needed either
* Apply suggestions from code review
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* add to_diff_dict
* update docs and expand testing
* Update _toctree.yml with compressed-tensors
* Update src/transformers/utils/quantization_config.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* update doc
* add note about saving a loaded model
---------
Co-authored-by: George Ohashi <george@neuralmagic.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: Sara Adkins <sara.adkins65@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
Co-authored-by: Dipika <dipikasikka1@gmail.com>
* update model card for opt
* add batch size to inference table
* [slow-run] opt
* [run-slow] opt
---------
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Co-authored-by: Avishai Elmakies <avishai.elma@cs.huji.ac.il>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: chengchengpei <5881383+chengchengpei@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Tibor Reiss <75096465+tibor-reiss@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
Co-authored-by: Muhammad Naufil <m.naufil1@gmail.com>
Co-authored-by: sizhky <yyeshr@gmail.com>
Co-authored-by: Umar Butler <umar@umar.au>
Co-authored-by: Jonathan Mamou <jonathan.mamou@intel.com>
Co-authored-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
Co-authored-by: George Ohashi <george@neuralmagic.com>
Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: Sara Adkins <sara.adkins65@gmail.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
Co-authored-by: Dipika <dipikasikka1@gmail.com>
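The central change in this PR routes OPT attention through PyTorch's `torch.nn.functional.scaled_dot_product_attention` (SDPA). Per the commit messages, it falls back to the eager path when a head mask is passed. The sketch below also assumes the same fallback when attention weights are requested, since SDPA returns only the attention output. It is a minimal illustration under those assumptions and is not the actual `modeling_opt.py` code. `eager_attention` and `attention` are hypothetical helpers: one compares a hand-written causal attention with SDPA, and the other shows the dispatch rule.

```python
import math

import torch
import torch.nn.functional as F


def eager_attention(q, k, v, causal=True, head_mask=None):
    """Reference attention: softmax(q @ k^T / sqrt(d)) @ v, computed explicitly.

    q, k, v have shape (batch, heads, seq, head_dim). head_mask, if given,
    has shape (heads,) and scales each head's attention probabilities.
    """
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-2, -1)) * scale
    if causal:
        seq_q, seq_k = q.size(-2), k.size(-2)
        # Block attention to future positions with -inf before the softmax.
        future = torch.triu(torch.ones(seq_q, seq_k, device=q.device), diagonal=1).bool()
        scores = scores.masked_fill(future, float("-inf"))
    probs = scores.softmax(dim=-1)
    if head_mask is not None:
        # Per-head scaling happens after the softmax; SDPA exposes no hook for this.
        probs = probs * head_mask.view(1, -1, 1, 1)
    return probs @ v, probs


def attention(q, k, v, attn_implementation="sdpa", head_mask=None, output_attentions=False):
    """Hypothetical dispatcher: use SDPA unless a feature it cannot express is requested."""
    if attn_implementation == "sdpa" and head_mask is None and not output_attentions:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True), None
    # Fall back to the eager path for head masks or returned attention weights.
    out, probs = eager_attention(q, k, v, causal=True, head_mask=head_mask)
    return out, (probs if output_attentions else None)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 7, 8) for _ in range(3))

out_eager, _ = attention(q, k, v, attn_implementation="eager")
out_sdpa, _ = attention(q, k, v, attn_implementation="sdpa")
# In float32 both paths should agree up to floating-point rounding.
print(torch.allclose(out_eager, out_sdpa, atol=1e-5))
```

Dispatching only on the features SDPA cannot express keeps the faster kernel as the default while preserving `head_mask` and `output_attentions` semantics. The real SDPA attention class in `modeling_opt.py` may differ in detail, for example in how it builds the padding mask that the test above exercises with left and right padding.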
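The tensor-placement fix described in the commit message relies on the `force` flag of `torch.Tensor.numpy`. The PyTorch documentation linked there describes `force=True` as equivalent to `t.detach().cpu().resolve_conj().resolve_neg().numpy()`. Below is a small CPU-only sketch showing the default call refusing a tensor that requires grad while `force=True` succeeds. `to_numpy` is a hypothetical helper, and the sketch assumes a PyTorch release that has the `force` parameter.

```python
import numpy as np
import torch


def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: get a tensor's data as a NumPy array regardless of device or autograd state."""
    # force=True behaves like t.detach().cpu().resolve_conj().resolve_neg().numpy(),
    # so the result may be a copy that does not share memory with `t`.
    return t.numpy(force=True)


x = torch.arange(6, dtype=torch.float32).reshape(2, 3).requires_grad_()
y = x * 2  # part of the autograd graph, so y.requires_grad is True

try:
    y.numpy()  # default force=False refuses tensors that require grad
except RuntimeError as err:
    print("force=False:", err)

arr = to_numpy(y)
print(arr.dtype, arr.shape)
```

The same default check also refuses tensors that live off the CPU, which is the "can't convert device type tensor to numpy" error the commit fixes.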
648 lines
26 KiB
Python
# coding=utf-8
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch OPT model."""

import copy
import tempfile
import unittest

import timeout_decorator  # noqa

from transformers import OPTConfig, is_torch_available
from transformers.testing_utils import (
    require_torch,
    require_torch_accelerator,
    require_torch_fp16,
    require_torch_sdpa,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        GPT2Tokenizer,
        OPTForCausalLM,
        OPTForQuestionAnswering,
        OPTForSequenceClassification,
        OPTModel,
    )


def prepare_opt_inputs_dict(
    config,
    input_ids,
    decoder_input_ids=None,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
):
    if attention_mask is None:
        attention_mask = input_ids.ne(config.pad_token_id)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "head_mask": head_mask,
    }


class OPTModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        embed_dim=16,
        num_labels=3,
        word_embed_proj_dim=16,
        type_sequence_label_size=2,
        attn_implementation="eager",
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.embed_dim = embed_dim
        self.num_labels = num_labels
        self.type_sequence_label_size = type_sequence_label_size
        self.word_embed_proj_dim = word_embed_proj_dim
        self.is_encoder_decoder = False
        self.attn_implementation = attn_implementation

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
            3,
        )
        input_ids[:, -1] = self.eos_token_id  # Eos Token

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.get_config()
        inputs_dict = prepare_opt_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def get_config(self):
        return OPTConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            embed_dim=self.embed_dim,
            is_encoder_decoder=False,
            word_embed_proj_dim=self.word_embed_proj_dim,
            attn_implementation=self.attn_implementation,
        )

    def get_pipeline_config(self):
        config = self.get_config()
        config.max_position_embeddings = 100
        return config

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = OPTModel(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]
        head_mask = inputs_dict["head_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

        # test no attention_mask works
        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
        _, past_key_values = outputs.to_tuple()
        output_from_no_past = model(next_input_ids)["last_hidden_state"]

        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]

        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))


@require_torch
class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (OPTModel, OPTForCausalLM, OPTForSequenceClassification, OPTForQuestionAnswering)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": OPTModel,
            "question-answering": OPTForQuestionAnswering,
            "text-classification": OPTForSequenceClassification,
            "text-generation": OPTForCausalLM,
            "zero-shot": OPTForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    is_encoder_decoder = False
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False

    # TODO: Fix the failed tests
    def is_pipeline_test_to_skip(
        self,
        pipeline_test_case_name,
        config_class,
        model_architecture,
        tokenizer_name,
        image_processor_name,
        feature_extractor_name,
        processor_name,
    ):
        if (
            pipeline_test_case_name == "QAPipelineTests"
            and tokenizer_name is not None
            and not tokenizer_name.endswith("Fast")
        ):
            # `QAPipelineTests` fails for a few models when the slower tokenizers are used.
            # (The slower tokenizers were never used for pipeline tests before the pipeline testing rework)
            # TODO: check (and possibly fix) the `QAPipelineTests` with slower tokenizers
            return True

        return False

    def setUp(self):
        self.model_tester = OPTModelTester(self)
        self.config_tester = ConfigTester(self, config_class=OPTConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in (OPTModel,):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs["inputs_embeds"] = wte(input_ids)
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    @require_torch_fp16
    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        model = OPTForCausalLM(config).eval().to(torch_device)
        model.half()
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_opt_sequence_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = OPTForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_opt_sequence_classification_model_for_multi_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = OPTForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        """
        Overrides the common test because it is flaky on tiny models.
        """
        max_new_tokens = 30

        tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350M")

        texts = [
            "hi here's a longer context, getting longer and",
            "Hello this is a very long sentence my friend, very long for real",
            "Today I am in Paris and",
        ]

        model_sdpa = OPTForCausalLM.from_pretrained(
            "facebook/opt-350M",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation="sdpa",
        ).to(torch_device)

        self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

        model_eager = OPTForCausalLM.from_pretrained(
            "facebook/opt-350M",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation="eager",
        ).to(torch_device)

        self.assertTrue(model_eager.config._attn_implementation == "eager")

        for _, submodule in model_eager.named_modules():
            if "SdpaAttention" in submodule.__class__.__name__:
                raise ValueError("The eager model should not have SDPA attention layers")

        has_sdpa = False
        for _, submodule in model_sdpa.named_modules():
            if "SdpaAttention" in submodule.__class__.__name__:
                has_sdpa = True
                break
        if not has_sdpa:
            raise ValueError("The SDPA model should have SDPA attention layers")

        for padding_side in ["left", "right"]:
            tokenizer.padding_side = padding_side
            tokenizer.pad_token = tokenizer.eos_token

            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

            res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
            res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

            with self.subTest(f"{padding_side}"):
                torch.testing.assert_close(
                    res_eager,
                    res_sdpa,
                    msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
                )

    @unittest.skip(reason="Does not work on the tiny model as we keep hitting edge cases.")
    def test_model_parallelism(self):
        super().test_model_parallelism()


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
    """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
    if a is None and b is None:
        return True
    try:
        if torch.allclose(a, b, atol=atol):
            return True
        raise
    except Exception:
        pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
        if a.numel() > 100:
            msg = f"tensor values are {pct_different:.1%} different."
        else:
            msg = f"{a} != {b}"
        if prefix:
            msg = prefix + ": " + msg
        raise AssertionError(msg)


def _long_tensor(tok_lst):
    return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)


@require_torch
class OPTModelIntegrationTests(unittest.TestCase):
    @slow
    def test_inference_no_head(self):
        model = OPTModel.from_pretrained("facebook/opt-350m").to(torch_device)
        input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])

        with torch.no_grad():
            output = model(input_ids=input_ids).last_hidden_state

        expected_shape = torch.Size((1, 11, 512))
        self.assertEqual(output.shape, expected_shape)
        # expected value works for CPU, as well as GPU (with TF32 disabled)
        expected_slice = torch.tensor(
            [
                [-0.28726277, -1.9241608, -0.3058734],
                [-1.2737825, -0.13332152, -0.18766522],
                [0.41159445, 0.1191957, -1.3107123],
            ],
            device=torch_device,
        )
        assert_tensors_close(output[0, :3, :3], expected_slice, atol=5e-5)


@require_torch
|
||
@slow
|
||
class OPTEmbeddingsTest(unittest.TestCase):
|
||
def setUp(self):
|
||
super().setUp()
|
||
self.path_model = "facebook/opt-350m"
|
||
|
||
def test_load_model(self):
|
||
try:
|
||
_ = OPTForCausalLM.from_pretrained(self.path_model)
|
||
except BaseException:
|
||
self.fail("Failed loading model")
|
||
|
||
def test_logits(self):
|
||
model = OPTForCausalLM.from_pretrained(self.path_model)
|
||
model = model.eval()
|
||
tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)
|
||
|
||
prompts = [
|
||
"Today is a beautiful day and I want to",
|
||
"In the city of",
|
||
"Paris is the capital of France and",
|
||
"Computers and mobile phones have taken",
|
||
]
|
||
# verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
|
||
inputs = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False)
|
||
logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(dim=-1)
|
||
# logits_meta = torch.load(self.path_logits_meta)
|
||
logits_meta = torch.Tensor(
|
||
[
|
||
[1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
|
||
[-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
|
||
[0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
|
||
[6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
|
||
]
|
||
)
|
||
assert torch.allclose(logits, logits_meta, atol=1e-4)
|
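

# Illustrative sketch, hypothetical helper (not part of transformers and not used by the tests in this file).
# `OPTGenerationTest.test_batch_generation` computes `num_paddings` as the length of the longest prompt minus the
# number of real tokens in the shorter prompt. With left padding, that equals the number of leading zeros in the
# shorter prompt's attention-mask row, which this helper counts on a plain Python list.
def _num_left_padding_tokens(attention_mask_row):
    num_pad = 0
    for mask_value in attention_mask_row:
        if mask_value != 0:
            # the first real token ends the left padding
            break
        num_pad += 1
    return num_pad


# Example: _num_left_padding_tokens([0, 0, 0, 1, 1, 1]) returns 3.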


@slow
class OPTGenerationTest(unittest.TestCase):
    @property
    def prompts(self):
        return [
            "Today is a beautiful day and I want",
            "In the city of",
            "Paris is the capital of France and",
            "Computers and mobile phones have taken",
        ]

    def test_generation_pre_attn_layer_norm(self):
        model_id = "facebook/opt-125m"

        EXPECTED_OUTPUTS = [
            "Today is a beautiful day and I want to",
            "In the city of New York, the city",
            "Paris is the capital of France and the capital",
            "Computers and mobile phones have taken over the",
        ]

        predicted_outputs = []
        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
        model = OPTForCausalLM.from_pretrained(model_id)

        for prompt in self.prompts:
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids

            generated_ids = model.generate(input_ids, max_length=10)

            generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            predicted_outputs += generated_string

        self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)

    def test_batch_generation(self):
        model_id = "facebook/opt-350m"

        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
        model = OPTForCausalLM.from_pretrained(model_id)
        model.to(torch_device)

        tokenizer.padding_side = "left"

        # use different length sentences to test batching
        sentences = [
            "Hello, my dog is a little",
            "Today, I",
        ]

        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch_device)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"].to(torch_device),
        )

        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
        output_non_padded = model.generate(input_ids=inputs_non_padded)

        # In the batch, the shorter second sentence is left-padded to the length of the first one. Generating it
        # alone with max_length reduced by its padding count yields the same number of new tokens as in the batch.
        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

        expected_output_sentence = [
            "Hello, my dog is a little bit of a dork.\nI'm a little bit",
            "Today, I was in the middle of a conversation with a friend about the",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])

    def test_generation_post_attn_layer_norm(self):
        model_id = "facebook/opt-350m"

        EXPECTED_OUTPUTS = [
            "Today is a beautiful day and I want to",
            "In the city of San Francisco, the city",
            "Paris is the capital of France and the capital",
            "Computers and mobile phones have taken over the",
        ]

        predicted_outputs = []
        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
        model = OPTForCausalLM.from_pretrained(model_id)

        for prompt in self.prompts:
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids

            generated_ids = model.generate(input_ids, max_length=10)

            generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            predicted_outputs += generated_string

        self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)

    @require_torch_accelerator
    @require_torch_fp16
    def test_batched_nan_fp16(self):
        # A bug in batched generation showed up starting with facebook/opt-1.3b, so instead of a tiny model this
        # test uses the smallest checkpoint where the problem was observed (opt-1.3b).
        # See https://github.com/huggingface/transformers/pull/17437 for more details.
        model_name = "facebook/opt-1.3b"
        tokenizer = GPT2Tokenizer.from_pretrained(model_name, use_fast=False, padding_side="left")

        model = OPTForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).to(torch_device)
        model = model.eval()

        batch = tokenizer(["Who are you?", "Joe Biden is the president of"], padding=True, return_tensors="pt")

        input_ids = batch["input_ids"].to(torch_device)
        attention_mask = batch["attention_mask"].to(torch_device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            # if the bug is present, the logits of the first (left-padded) sequence contain NaNs
            self.assertFalse(torch.isnan(outputs.logits[0]).any().item())

    @slow
    def test_contrastive_search_opt(self):
        article = (
            "A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
            "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
            "there?"
        )

        opt_tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b")
        opt_model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b").to(torch_device)
        input_ids = opt_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        outputs = opt_model.generate(input_ids, penalty_alpha=0.6, top_k=5, max_length=256)
        generated_text = opt_tokenizer.batch_decode(outputs, skip_special_tokens=True)

        self.assertListEqual(
            generated_text,
            [
                "A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I "
                "am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have "
                "you lived there?\nStatue: A hundred years.\nHuman: And you’re from what country?\nStatue: The United "
                "States of America.\nHuman: Why did you come to America?\nStatue: I came to escape the tyranny of my "
                "country.\nHuman: What tyranny?\nStatue: They didn’t let me speak my mind.\nHuman: What was your "
                "country?\nStatue: It was a country of immigrants.\nHuman: Who were the immigrants?\nStatue: They "
                "were from all over the world.\nHuman: What language did they speak?\nStatue: French, Spanish, "
                "Italian, German, English—you name it.\nHuman: And where did they come from?\nStatue: They came from "
                "every country in the world.\nHuman: And you were born in what country?\nStatue: I was born in "
                "France.\nHuman: And your parents were French?\nStatue"
            ],
        )
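

# Illustrative sketch, hypothetical helper: a toy restatement of the selection rule, not the transformers
# implementation. `test_contrastive_search_opt` calls `generate(..., penalty_alpha=0.6, top_k=5)`, which runs
# contrastive search: among the `top_k` most probable next tokens, pick the candidate v that maximizes
#     (1 - penalty_alpha) * p(v) - penalty_alpha * max_j cos(h_v, h_j)
# where h_v is the candidate's hidden state and h_j ranges over the hidden states of the tokens already in the
# context. Here probabilities and hidden states are plain Python lists, and `context_states` must be non-empty.
def _contrastive_search_pick(candidate_probs, candidate_states, context_states, penalty_alpha):
    def cosine(u, v):
        dot = sum(x * y for x, y in zip(u, v))
        norm = sum(x * x for x in u) ** 0.5 * sum(y * y for y in v) ** 0.5
        return dot / norm if norm > 0 else 0.0

    best_index, best_score = None, float("-inf")
    for index, (prob, state) in enumerate(zip(candidate_probs, candidate_states)):
        # degeneration penalty: similarity to the most similar token already in the context
        penalty = max(cosine(state, ctx) for ctx in context_states)
        score = (1 - penalty_alpha) * prob - penalty_alpha * penalty
        if score > best_score:
            best_index, best_score = index, score
    return best_index


# Example: with candidate_probs=[0.6, 0.4], candidate_states=[[1.0, 0.0], [0.0, 1.0]] and
# context_states=[[1.0, 0.0]], penalty_alpha=0.0 picks index 0 (plain greedy choice), while penalty_alpha=0.6
# picks index 1 because candidate 0 repeats the context.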