
* add sdpa to OPT
* chore: remove redundant whitespace in OPTDecoder class
* fixup
* bug fix
* add sdpa and attention generate test
* fixup
* Refactor OPTAttention forward method for improved readability and maintainability
* undo refactor for _shape and key,val states
* add OPT to doc, fixup didn't find it for some reason
* change order
* change default attn_implementation in testing to eager
* [run-slow] opt
* change test_eager_matches_sdpa_generate to the one from llama
* Update default attention implementation in testing common
* [run-slow] opt
* remove unneeded print
* [run-slow] opt
* refactor model testers to have attn_implementation="eager"
* [run-slow] opt
* convert test_eager_matches_sdpa_generate to opt-350M
* bug fix when creating mask for opt
* [run-slow] opt
* if layer head mask is set, default to eager
* if head mask is not None, fall back to eager
* [run-slow] opt
* Update src/transformers/models/opt/modeling_opt.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
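For context on the attn_implementation switch these OPT commits introduce (SDPA by default, falling back to eager when head masks are passed), a minimal usage sketch follows; the checkpoint and kwargs are illustrative and not taken from the PR itself:

    # Select the attention backend when loading OPT: "sdpa" routes through
    # torch.nn.functional.scaled_dot_product_attention, "eager" is the reference path.
    from transformers import AutoModelForCausalLM

    model_sdpa = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", attn_implementation="sdpa")
    model_eager = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", attn_implementation="eager")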
* Clean up Unpack imports (#33631)
clean up Unpack imports
* Fix DPT /Dinov2 sdpa regression on main (#33660)
* fallback to eager if output attentions.
* fix copies
* handle dependency errors in check_imports (#33622)
* handle dependency errors in check_imports
* change log level to warning
* add back self.max_position_embeddings = config.max_position_embeddings (#33550)
* add back self.max_position_embeddings = config.max_position_embeddings
* fix-copies
* Fix Llava conversion for LlavaQwen2ForCausalLM with Clip vision tower (#33613)
fix llavaqwen2 model conversion
* Uniformize kwargs for Udop processor and update docs (#33628)
* Add optional kwargs and uniformize udop
* cleanup Unpack
* nit Udop
* Generation: deprecate `PreTrainedModel` inheriting from `GenerationMixin` (#33203)
* Enable BNB multi-backend support (#31098)
* enable cpu bnb path
* fix style
* fix code style
* fix 4 bit path
* Update src/transformers/utils/import_utils.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* add multi backend refactor tests
* fix style
* tweak 4bit quantizer + fix corresponding tests
* tweak 8bit quantizer + *try* fixing corresponding tests
* fix dequant bnb 8bit
* account for Intel CPU in variability of expected outputs
* enable cpu and xpu device map
* further tweaks to account for Intel CPU
* fix autocast to work with both cpu + cuda
* fix comments
* fix comments
* switch to testing_utils.torch_device
* allow for xpu in multi-gpu tests
* fix tests 4bit for CPU NF4
* fix bug with is_torch_xpu_available needing to be called as func
* avoid issue where test reports attr err due to other failure
* fix formatting
* fix typo from resolving of merge conflict
* polish based on last PR review
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* fix CI
* Update src/transformers/integrations/integration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/integrations/integration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix error log
* fix error msg
* add \n in error log
* make quality
* rm bnb cuda restriction in doc
* cpu model don't need dispatch
* fix doc
* fix style
* check cuda available in testing
* fix tests
* Update docs/source/en/model_doc/chameleon.md
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Update docs/source/en/model_doc/llava_next.md
Co-authored-by: Aarni Koskela <akx@iki.fi>
* Update tests/quantization/bnb/test_4bit.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* Update tests/quantization/bnb/test_4bit.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* fix doc
* fix check multibackends
* fix import sort
* remove check torch in bnb
* docs: update bitsandbytes references with multi-backend info
* docs: fix small mistakes in bnb paragraph
* run formatting
* revert bnb check
* move bnb multi-backend check to import_utils
* Update src/transformers/utils/import_utils.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* fix bnb check
* minor fix for bnb
* check lib first
* fix code style
* Revert "run formatting"
This reverts commit ac108c6d6b.
* fix format
* give warning when bnb version is low and no cuda found
* fix device assignment check to be multi-device capable
* address akx feedback on get_avlbl_dev fn
* revert partially, as we don't want the function to be public, as docs would be too much (enforced)
---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
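Since the bitsandbytes commits above mainly change which device the 4-bit/8-bit paths resolve to, here is a hedged sketch of the loading pattern they target (checkpoint id illustrative; CPU/XPU execution depends on the installed bitsandbytes backend):

    # Standard 4-bit bitsandbytes loading; with multi-backend support the same
    # config can land on CUDA, XPU, or CPU depending on what is available.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m", quantization_config=bnb_config, device_map="auto"
    )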
* Fix error string after refactoring into get_chat_template (#33652)
* Fix error string after refactoring into get_chat_template
* Take suggestion from CR
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
---------
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* uniformize git processor (#33668)
* uniformize git processor
* update doctring
* Modular `transformers`: modularity and inheritance for new model additions (#33248)
* update example
* update
* push the converted diff files for testing and ci
* correct one example
* fix class attributes and docstring
* nits
* oups
* fixed config!
* update
* nits
* class attributes are not matched against the other, this is missing
* fixed overwriting self.xxx now onto the attributes I think
* partial fix, now order with docstring
* fix docstring order?
* more fixes
* update
* fix missing docstrings!
* examples don't all work yet
* fixup
* nit
* updated
* hick
* update
* delete
* update
* update
* update
* fix
* all default
* no local import
* fix more diff
* some fix related to "safe imports"
* push fixed
* add helper!
* style
* add a check
* all by default
* add the
* update
* FINALLY!
* nit
* fix config dependencies
* man that is it
* fix fix
* update diffs
* fix the last issue
* re-default to all
* alll the fixes
* nice
* fix properties vs setter
* fixup
* updates
* update dependencies
* make sure to install what needs to be installed
* fixup
* quick fix for now
* fix!
* fixup
* update
* update
* updates
* whitespaces
* nit
* fix
* simplify everything, and make it file agnostic (should work for image processors)
* style
* finish fixing all import issues
* fixup
* empty modeling should not be written!
* Add logic to find who depends on what
* update
* cleanup
* update
* update gemma to support positions
* some small nits
* this is the correct docstring for gemma2
* fix merging of docstrings
* update
* fixup
* update
* take doc into account
* styling
* update
* fix hidden activation
* more fixes
* final fixes!
* fixup
* fixup instruct blip video
* update
* fix bugs
* align gemma2 with the rest as well
* updates
* revert
* update
* more reversions
* grind
* more
* arf
* update
* order will matter
* finish del stuff
* update
* rename to modular
* fixup
* nits
* update makefile
* fixup
* update order of the checks!
* fix
* fix docstring that has a call inside
* fix conversion check
* style
* add some initial documentation
* update
* update doc
* some fixup
* updates
* yups
* Mostly todo, give me a minute
* update
* fixup
* revert some stuff
* Review docs for the modular transformers (#33472)
Docs
* good update
* fixup
* mmm current updates lead to this code
* okay, this fixes it
* cool
* fixes
* update
* nit
* updates
* nits
* fix doc
* update
* revert bad changes
* update
* updates
* proper update
* update
* update?
* up
* update
* cool
* nits
* nits
* bon bon
* fix
* ?
* minimise changes
* update
* update
* update
* updates?
* fixed gemma2
* kind of a hack
* nits
* update
* remove `diffs` in favor of `modular`
* fix make fix copies
---------
Co-authored-by: Lysandre Debut <hi@lysand.re>
* Fix CIs post merging modular transformers (#33681)
update
* Fixed docstring for cohere model regarding unavailability of prune_he… (#33253)
* Fixed docstring for cohere model regarding unavailability of prune_heads() method
The docstring mentions that the Cohere model supports the prune_heads() method. I have fixed the docstring by explicitly mentioning that it doesn't support that functionality.
* Update src/transformers/models/cohere/modeling_cohere.py
---------
Co-authored-by: Lysandre Debut <hi@lysand.re>
* Generation tests: update imagegpt input name, remove unused functions (#33663)
* Improve Error Messaging for Flash Attention 2 on CPU (#33655)
Update flash-attn error message on CPU
Rebased to latest branch
* Gemma2: fix config initialization (`cache_implementation`) (#33684)
* Fix ByteLevel alphabet missing when Sequence pretokenizer is used (#33556)
* Fix ByteLevel alphabet missing when Sequence pretokenizer is used
* Fixed formatting with `ruff`.
* Uniformize kwargs for image-text-to-text processors (#32544)
* uniformize FUYU processor kwargs
* Uniformize instructblip processor kwargs
* Fix processor kwargs and tests Fuyu, InstructBlip, Kosmos2
* Uniformize llava_next processor
* Fix save_load test for processor with chat_template only as extra init args
* Fix import Unpack
* Fix Fuyu Processor import
* Fix FuyuProcessor import
* Fix FuyuProcessor
* Add defaults for specific kwargs kosmos2
* Fix Udop to return BatchFeature instead of BatchEncoding and uniformize kwargs
* Add tests processor Udop
* remove Copied from in processing Udop, as the change of input order is caused by BatchEncoding -> BatchFeature
* Fix overwrite tests kwargs processors
* Add warnings and BC for changes in processor inputs order, change docs, add BC for text_pair as arg for Udop
* Fix processing test fuyu
* remove unnecessary pad_token check in instructblip ProcessorTest
* Fix BC tests and cleanup
* FIx imports fuyu
* Uniformize Pix2Struct
* Fix wrong name for FuyuProcessorKwargs
* Fix slow tests reversed inputs align fuyu llava-next, change udop warning
* Fix wrong logging import udop
* Add check images text input order
* Fix copies
* change text pair handling when positional arg
* rebase on main, fix imports in test_processing_common
* remove optional args and udop uniformization from this PR
* fix failing tests
* remove unnecessary test, fix processing utils and test processing common
* cleanup Unpack
* cleanup
* fix conflict grounding dino
* 🚨🚨 Setting default behavior of assisted decoding (#33657)
* tests: fix pytorch tensor placement errors (#33485)
This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy"
According to the PyTorch documentation, torch.Tensor.numpy(force=False)
performs the conversion only if the tensor is on CPU (plus a few other restrictions),
which is not the case here. In our case we need force=True since we just
need the data and don't care about tensor coherency.
Fixes: #33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
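A minimal sketch of the force=True behaviour described above (assuming a tensor on a non-CPU device):

    import torch

    t = torch.ones(3, device="cuda" if torch.cuda.is_available() else "cpu")
    # .numpy() raises for a CUDA/XPU tensor; force=True copies it to CPU first.
    arr = t.numpy(force=True)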
* bump tokenizers, fix added tokens fast (#32535)
* update based on tokenizers release
* update
* nits
* update
* revert re addition
* don't break that yet
* fmt
* revert unwanted
* update tokenizers version
* update dep table
* update
* update in conversion script as well
* some fix
* revert
* fully revert
* fix training
* remove set trace
* fixup
* update
* update
* [Pixtral] Improve docs, rename model (#33491)
* Improve docs, rename model
* Fix style
* Update repo id
* fix code quality after merge
* HFQuantizer implementation for compressed-tensors library (#31704)
* Add compressed-tensors HFQuantizer implementation
* flag serializable as False
* run
* revive lines deleted by ruff
* fixes to load+save from sparseml, edit config to quantization_config, and load back
* address satrat comment
* compressed_tensors to compressed-tensors and revert back is_serializable
* rename quant_method from sparseml to compressed-tensors
* tests
* edit tests
* clean up tests
* make style
* cleanup
* cleanup
* add test skip for when compressed tensors is not installed
* remove pydantic import + style
* delay torch import in test
* initial docs
* update main init for compressed tensors config
* make fix-copies
* docstring
* remove fill_docstring
* Apply suggestions from code review
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* review comments
* review comments
* comments - suppress warnings on state dict load, tests, fixes
* bug-fix - remove unnecessary call to apply quant lifecycle
* run_compressed compatibility
* revert changes not needed for compression
* no longer need unexpected keys fn
* unexpected keys not needed either
* Apply suggestions from code review
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* add to_diff_dict
* update docs and expand testing
* Update _toctree.yml with compressed-tensors
* Update src/transformers/utils/quantization_config.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* update doc
* add note about saving a loaded model
---------
Co-authored-by: George Ohashi <george@neuralmagic.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: Sara Adkins <sara.adkins65@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
Co-authored-by: Dipika <dipikasikka1@gmail.com>
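A hedged sketch of the load path the compressed-tensors quantizer above enables; the repository id is hypothetical and assumes a checkpoint saved with a compressed-tensors quantization_config:

    # from_pretrained dispatches to the compressed-tensors HFQuantizer when the
    # checkpoint's config.json carries a compressed-tensors quantization_config.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("org/llama-compressed-w8a8")  # hypothetical checkpoint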
* update model card for opt
* add batch size to inference table
* [run-slow] opt
* [run-slow] opt
---------
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Co-authored-by: Avishai Elmakies <avishai.elma@cs.huji.ac.il>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: chengchengpei <5881383+chengchengpei@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Tibor Reiss <75096465+tibor-reiss@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
Co-authored-by: Muhammad Naufil <m.naufil1@gmail.com>
Co-authored-by: sizhky <yyeshr@gmail.com>
Co-authored-by: Umar Butler <umar@umar.au>
Co-authored-by: Jonathan Mamou <jonathan.mamou@intel.com>
Co-authored-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
Co-authored-by: George Ohashi <george@neuralmagic.com>
Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: Sara Adkins <sara.adkins65@gmail.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
Co-authored-by: Dipika <dipikasikka1@gmail.com>
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import unittest

import numpy as np

from transformers import OPTConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, slow

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_tf_available():
    import tensorflow as tf

    from transformers import GPT2Tokenizer, TFOPTForCausalLM, TFOPTModel


def prepare_opt_inputs_dict(config, input_ids, attention_mask=None, head_mask=None):
    if attention_mask is None:
        attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


@require_tf
class TFOPTModelTester:
    config_cls = OPTConfig
    config_updates = {}
    hidden_act = "gelu"

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        embed_dim=16,
        word_embed_proj_dim=16,
        attn_implementation="eager",
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.embed_dim = embed_dim
        self.word_embed_proj_dim = word_embed_proj_dim
        self.is_encoder_decoder = False
        self.attn_implementation = attn_implementation

    def prepare_config_and_inputs_for_common(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
        eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
        input_ids = tf.concat([input_ids, eos_tensor], axis=1)

        config = self.config_cls(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            embed_dim=self.embed_dim,
            word_embed_proj_dim=self.word_embed_proj_dim,
            is_encoder_decoder=False,
            attn_implementation=self.attn_implementation,
            **self.config_updates,
        )
        inputs_dict = prepare_opt_inputs_dict(config, input_ids)
        return config, inputs_dict

    def check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = TFOPTModel(config=config)
        input_ids = inputs_dict["input_ids"]

        input_ids = input_ids[:1, :]
        attention_mask = inputs_dict["attention_mask"][:1, :]
        self.batch_size = 1

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)

        # append to next input_ids and attention_mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]

        self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

        # select random slice
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)


@require_tf
class TFOPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (TFOPTModel, TFOPTForCausalLM) if is_tf_available() else ()
    all_generative_model_classes = (TFOPTForCausalLM,) if is_tf_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": TFOPTModel, "text-generation": TFOPTForCausalLM} if is_tf_available() else {}
    )
    is_encoder_decoder = False
    test_pruning = False
    test_onnx = False
    onnx_min_opset = 10

    def setUp(self):
        self.model_tester = TFOPTModelTester(self)
        self.config_tester = ConfigTester(self, config_class=OPTConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_decoder_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_resize_token_embeddings(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def _get_word_embedding_weight(model, embedding_layer):
            if hasattr(embedding_layer, "weight"):
                return embedding_layer.weight
            else:
                # Here we build the word embedding weights if they don't exist,
                # and then retry to get the attribute once built.
                model.build_in_name_scope()
                if hasattr(embedding_layer, "weight"):
                    return embedding_layer.weight
                else:
                    return None

        for model_class in self.all_model_classes:
            for size in [config.vocab_size - 10, config.vocab_size + 10]:
                # build the embeddings
                model = model_class(config=config)
                old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
                old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())

                # reshape the embeddings
                model.resize_token_embeddings(size)
                new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
                new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())

                # check that the resized embeddings size matches the desired size.
                assert_size = size if size is not None else config.vocab_size

                self.assertEqual(new_input_embeddings.shape[0], assert_size)

                # check that weights remain the same after resizing
                models_equal = True
                for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()):
                    if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
                        models_equal = False
                self.assertTrue(models_equal)

                if old_output_embeddings is not None and new_output_embeddings is not None:
                    self.assertEqual(new_output_embeddings.shape[0], assert_size)

                    models_equal = True
                    for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()):
                        if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
                            models_equal = False
                    self.assertTrue(models_equal)


def _long_tensor(tok_lst):
    return tf.constant(tok_lst, dtype=tf.int32)


@require_tf
class TFOPTHeadTests(unittest.TestCase):
    vocab_size = 99

    def _get_config_and_data(self):
        eos_column_vector = tf.ones((4, 1), dtype=tf.int32) * 2
        input_ids = tf.concat([ids_tensor((4, 6), self.vocab_size - 3) + 3, eos_column_vector], axis=1)
        batch_size = input_ids.shape[0]
        config = OPTConfig(
            vocab_size=self.vocab_size,
            hidden_size=24,
            num_hidden_layers=2,
            num_attention_heads=2,
            ffn_dim=32,
            max_position_embeddings=48,
            eos_token_id=2,
            pad_token_id=1,
            bos_token_id=0,
        )
        return config, input_ids, batch_size


@require_sentencepiece
@require_tf
class OPTModelIntegrationTests(unittest.TestCase):
    @slow
    def test_inference_no_head(self):
        model = TFOPTModel.from_pretrained("facebook/opt-350m")
        input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        attention_mask = tf.not_equal(input_ids, model.config.pad_token_id)
        with tf.GradientTape():
            output = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        expected_shape = (1, 11, 512)
        self.assertEqual(output.shape, expected_shape)
        expected_slice = tf.constant(
            [[-0.2873, -1.9218, -0.3033], [-1.2710, -0.1338, -0.1902], [0.4095, 0.1214, -1.3121]]
        )
        self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-3))

        xla_generate = tf.function(model, jit_compile=True)
        output = xla_generate(input_ids, attention_mask)[0]
        self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-2))


@require_tf
@slow
class TFOPTEmbeddingsTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.path_model = "facebook/opt-350m"

    def test_logits(self):
        model = TFOPTForCausalLM.from_pretrained(self.path_model)
        tokenizer = GPT2Tokenizer.from_pretrained(self.path_model)

        prompts = [
            "Today is a beautiful day and I want to",
            "In the city of",
            "Paris is the capital of France and",
            "Computers and mobile phones have taken",
        ]
        # verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
        inputs = tokenizer(prompts, return_tensors="tf", padding=True, add_special_tokens=False)
        logits = tf.math.reduce_mean(model(inputs.input_ids, attention_mask=inputs.attention_mask)[0], axis=-1)
        logits_meta = tf.constant(
            [
                [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
                [-4.7073, -10.6276, -3.9415, -21.5242, -0.2822, -0.2822, -0.2822, -0.2822, -0.2822],
                [0.6247, -3.4229, -8.9179, -1.4297, -14.1650, 1.4146, -9.0218, -0.2703, -0.2703],
                [6.4783, -1.9913, -10.7926, -2.3336, 1.5092, -0.9974, -6.8213, 1.3477, 1.3477],
            ]
        )
        self.assertTrue(np.allclose(logits, logits_meta, atol=1e-4))

        xla_generate = tf.function(model, jit_compile=True)
        logits = tf.math.reduce_mean(xla_generate(inputs.input_ids, attention_mask=inputs.attention_mask)[0], axis=-1)
        self.assertTrue(np.allclose(logits, logits_meta, atol=1e-4))


@require_tf
@slow
class TFOPTGenerationTest(unittest.TestCase):
    @property
    def prompts(self):
        return [
            "Today is a beautiful day and I want",
            "In the city of",
            "Paris is the capital of France and",
            "Computers and mobile phones have taken",
        ]

    def test_generation_pre_attn_layer_norm(self):
        model_id = "facebook/opt-125m"

        EXPECTED_OUTPUTS = [
            "Today is a beautiful day and I want to",
            "In the city of New York, the city",
            "Paris is the capital of France and the capital",
            "Computers and mobile phones have taken over the",
        ]

        predicted_outputs = []
        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
        model = TFOPTForCausalLM.from_pretrained(model_id)

        for prompt in self.prompts:
            input_ids = tokenizer(prompt, return_tensors="tf").input_ids

            generated_ids = model.generate(input_ids, max_length=10)

            generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            predicted_outputs += generated_string

        self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)

    def test_batch_generation(self):
        model_id = "facebook/opt-350m"

        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
        model = TFOPTForCausalLM.from_pretrained(model_id)

        tokenizer.padding_side = "left"

        # use different length sentences to test batching
        sentences = [
            "Hello, my dog is a little",
            "Today, I",
        ]

        inputs = tokenizer(sentences, return_tensors="tf", padding=True)
        input_ids = inputs["input_ids"]

        outputs = model.generate(input_ids=input_ids, attention_mask=inputs["attention_mask"])

        inputs_non_padded = tokenizer(sentences[0], return_tensors="tf").input_ids
        output_non_padded = model.generate(input_ids=inputs_non_padded)

        num_paddings = inputs_non_padded.shape[-1] - tf.math.reduce_sum(
            tf.cast(inputs["attention_mask"][-1], tf.int64)
        )
        inputs_padded = tokenizer(sentences[1], return_tensors="tf").input_ids
        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

        expected_output_sentence = [
            "Hello, my dog is a little bit of a dork.\nI'm a little bit",
            "Today, I was in the middle of a conversation with a friend about the",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])

    def test_generation_post_attn_layer_norm(self):
        model_id = "facebook/opt-350m"

        EXPECTED_OUTPUTS = [
            "Today is a beautiful day and I want to",
            "In the city of San Francisco, the city",
            "Paris is the capital of France and the capital",
            "Computers and mobile phones have taken over the",
        ]

        predicted_outputs = []
        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
        model = TFOPTForCausalLM.from_pretrained(model_id)

        for prompt in self.prompts:
            input_ids = tokenizer(prompt, return_tensors="tf").input_ids

            generated_ids = model.generate(input_ids, max_length=10)

            generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            predicted_outputs += generated_string

        self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)