Stop storing references to bound methods via tf.function (#24146)

* Stop storing references to bound methods in tf.functions

* Remove the gc.collect calls now that we resolved the underlying problem

* Remove the default signature from model.serving entirely, big cleanup

* Remove _prune_signature as self.input_signature can prune itself

* Restore serving docstring

* Update int support test to check the input signature

* Make sure other tests also use model.input_signature and not serving.input_signature

* Restore _prune_signature

* Remove the doctest GC now it's no longer needed

* Correct core tests to use the pruned sig

* order lines correctly in core tests

* Add eager_serving back with a deprecation warning
Matt 2023-06-13 19:04:22 +01:00 committed by GitHub
parent b979a2064d
commit 3bd1fe4315
8 changed files with 34 additions and 48 deletions
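
In essence, the change replaces a per-instance `tf.function` that wraps a bound method (and therefore holds a reference back to the model) with a class-level `@tf.function` that is only specialized into concrete functions at save time. A minimal sketch of the before/after pattern (illustrative class names, not the actual Transformers code):

    import tensorflow as tf

    # Before: each instance stores a tf.function wrapping its own bound method, so the
    # instance indirectly references itself (model -> tf.function -> bound method -> model).
    class LeakyModel(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.serving = tf.function(self.eager_serving)

        def eager_serving(self, inputs):
            return inputs

    # After: the tf.function lives on the class, so no instance attribute points back at the instance.
    class CleanModel(tf.keras.Model):
        @tf.function
        def serving(self, inputs):
            return inputs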


@@ -1171,12 +1171,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         self.config = config
         self.name_or_path = config.name_or_path
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
-        if not hasattr(self, "serving"):  # Don't overwrite existing serving signatures
-            self.serving = tf.function(
-                self.eager_serving, input_signature=[self._prune_signature(self.input_signature)]
-            )
-        # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
-        self._set_save_spec(self.serving.input_signature[0])
+        self._set_save_spec(self._prune_signature(self.input_signature))

     def get_config(self):
         return self.config.to_dict()
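
For reference, `_prune_signature` keeps only the entries of `input_signature` that the model's `call()` actually accepts, so the save spec no longer has to be read back from `self.serving`. Conceptually it amounts to something like this sketch (the idea, not the exact implementation):

    import inspect

    def prune_signature(model, signature):
        # Drop signature entries whose names are not parameters of model.call()
        # (a sketch of the idea behind TFPreTrainedModel._prune_signature).
        call_params = set(inspect.signature(model.call).parameters)
        return {name: spec for name, spec in signature.items() if name in call_params}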
@@ -1226,15 +1222,31 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         head_mask = tf.cast(head_mask, tf.float32)  # switch to float if need + fp16 compatibility
         return head_mask

+    @tf.function
+    def serving(self, inputs):
+        """
+        Args:
+        Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
+        functions when saving with `save_pretrained`.
+            inputs (`Dict[str, tf.Tensor]`):
+                The input of the saved model as a dictionary of tensors.
+        """
+        output = self.call(inputs)
+
+        return self.serving_output(output)
+
     def eager_serving(self, inputs):
         """
-        Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
-        it to generate multiple signatures later.
+        Method used for serving the model. This method is deprecated, and will be removed.

         Args:
             inputs (`Dict[str, tf.Tensor]`):
                 The input of the saved model as a dictionary of tensors.
         """
+        warnings.warn(
+            "The function `eager_serving` is deprecated and will be removed in version 4.32.0 of Transformers",
+            FutureWarning,
+        )
         output = self.call(inputs)

         return self.serving_output(output)
@@ -2409,17 +2421,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
                 self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
             if signatures is None:
-                if any(spec.dtype == tf.int32 for spec in self.serving.input_signature[0].values()):
+                sig = self._prune_signature(self.input_signature)
+                serving_default = self.serving.get_concrete_function(sig)
+                if any(spec.dtype == tf.int32 for spec in sig.values()):
                     int64_spec = {
                         key: tf.TensorSpec(
                             shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
                         )
-                        for key, spec in self.serving.input_signature[0].items()
+                        for key, spec in sig.items()
                     }
-                    int64_serving = tf.function(self.eager_serving, input_signature=[int64_spec])
-                    signatures = {"serving_default": self.serving, "int64_serving": int64_serving}
+                    int64_serving = self.serving.get_concrete_function(int64_spec)
+                    signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
                 else:
-                    signatures = self.serving
+                    signatures = serving_default
             saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
             self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
             logger.info(f"Saved model created in {saved_model_dir}")
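
Because `serving` is now a plain class-level `tf.function` without a baked-in signature, `save_pretrained` can derive several concrete functions from it on demand instead of constructing a separate `tf.function` per signature. A self-contained sketch of that mechanism with a toy module (names and the output path are illustrative):

    import tensorflow as tf

    class Toy(tf.Module):
        @tf.function
        def serving(self, inputs):
            # Works for any dtype of inputs["x"]; specialization happens per concrete function.
            return {"out": tf.cast(inputs["x"], tf.float32) * 2.0}

    toy = Toy()
    int32_sig = {"x": tf.TensorSpec(shape=[None], dtype=tf.int32, name="x")}
    int64_sig = {"x": tf.TensorSpec(shape=[None], dtype=tf.int64, name="x")}

    # One tf.function, two concrete functions exported as separate SavedModel signatures.
    signatures = {
        "serving_default": toy.serving.get_concrete_function(int32_sig),
        "int64_serving": toy.serving.get_concrete_function(int64_sig),
    }
    tf.saved_model.save(toy, "/tmp/toy_saved_model", signatures=signatures)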


@@ -1882,13 +1882,6 @@ def preprocess_string(string, skip_cuda_tests):
     if not is_cuda_found:
         modified_string = "".join(codeblocks)

-    if ">>>" in modified_string:
-        lines = modified_string.split("\n")
-        indent = len(lines[-1]) - len(lines[-1].lstrip())
-        cleanup = ">>> import gc; gc.collect() # doctest: +IGNORE_RESULT"
-        modified_string += "\n" + " " * indent + cleanup
-
     return modified_string


@@ -2676,7 +2676,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
-        self.model._set_save_spec(inputs=self.serving.input_signature)
+        self.model._set_save_spec(self._prune_signature(self.input_signature))
         self.use_cache = config.use_cache
         # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
         self.bias_layer = BiasLayer(


@@ -1,6 +1,5 @@
 from __future__ import annotations

-import gc
 import json
 import os
 import shutil
@@ -551,11 +550,6 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
 @require_sentencepiece
 @require_tokenizers
 class TFRagModelIntegrationTests(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-
     @cached_property
     def token_model(self):
         return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
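
The `gc.collect()` teardowns removed here (and in the SAM and XGLM tests below) were papering over the fact that the bound-method cycle kept whole models reachable until Python's cycle collector happened to run. A small illustration of that behaviour, using a stand-in `tf.Module` rather than anything from this PR:

    import gc
    import weakref

    import tensorflow as tf

    class LeakyModule(tf.Module):
        def __init__(self):
            super().__init__()
            # Old pattern: a tf.function wrapping a bound method is stored on the instance.
            self.serving = tf.function(self.eager_serving)

        def eager_serving(self, inputs):
            return inputs

    module = LeakyModule()
    ref = weakref.ref(module)
    del module
    print(ref() is None)  # typically False: the reference cycle keeps the object alive
    gc.collect()
    print(ref() is None)  # True once the cycle collector has broken the cycle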


@@ -17,7 +17,6 @@
 from __future__ import annotations

-import gc
 import inspect
 import unittest
@@ -431,11 +430,6 @@ def prepare_dog_img():
 @require_tf
 @slow
 class TFSamModelIntegrationTest(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-
     def test_inference_mask_generation_no_point(self):
         model = TFSamModel.from_pretrained("facebook/sam-vit-base")
         processor = SamProcessor.from_pretrained("facebook/sam-vit-base")


@@ -15,7 +15,6 @@
 from __future__ import annotations

-import gc
 import unittest

 from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
@@ -173,11 +172,6 @@ class TFXGLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
 @require_tf
 class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        # clean-up as much as possible GPU memory occupied by PyTorch
-        gc.collect()
-
     @slow
     def test_lm_generate_xglm(self, verify_outputs=True):
         model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")


@@ -1687,14 +1687,10 @@ class TFModelTesterMixin:
                 if tensor.dtype.is_integer:
                     self.assertTrue(tensor.dtype == tf.int32, "Integer dummy inputs should be tf.int32!")

-            # Also confirm that the serving sig uses int32
-            if hasattr(model, "serving"):
-                serving_sig = model.serving.input_signature
-                for key, tensor_spec in serving_sig[0].items():
-                    if tensor_spec.dtype.is_integer:
-                        self.assertTrue(
-                            tensor_spec.dtype == tf.int32, "Serving signatures should use tf.int32 for ints!"
-                        )
+            # Also confirm that the input_signature uses int32
+            for key, tensor_spec in model.input_signature.items():
+                if tensor_spec.dtype.is_integer:
+                    self.assertTrue(tensor_spec.dtype == tf.int32, "Input signatures should use tf.int32 for ints!")

     def test_generate_with_headmasking(self):
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]


@@ -217,17 +217,18 @@ class TFCoreModelTesterMixin:
         for model_class in self.all_model_classes:
             class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config)
+            class_sig = model._prune_signature(model.input_signature)
             num_out = len(model(class_inputs_dict))

             for key in list(class_inputs_dict.keys()):
                 # Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
-                if key not in model.serving.input_signature[0]:
+                if key not in class_sig:
                     del class_inputs_dict[key]
                 # Check it's a tensor, in case the inputs dict has some bools in it too
                 elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
                     class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)

-            if set(class_inputs_dict.keys()) != set(model.serving.input_signature[0].keys()):
+            if set(class_inputs_dict.keys()) != set(class_sig.keys()):
                 continue  # Some models have inputs that the preparation functions don't create, we skip those

             with tempfile.TemporaryDirectory() as tmpdirname: