Call _set_save_spec() when creating TF models (#19321)

* Add a build_from_serving_sig_and_dummies method and replace all calls like model(model.dummy_inputs) with it.

* make fixup

* Remove the overridden save() as this is no longer necessary

* Also call _set_save_spec(), the last missing piece

* Ensure we set the save spec when loading from config too

* Turn this whole thing into a one-line PR

* Turn this whole thing into a one-line PR

* Turn this whole thing into a one-line PR

Co-authored-by: Your Name <you@example.com>
This commit was authored by Matt on 2022-10-05 18:03:49 +01:00 and committed via GitHub.
parent c875a96eb1
commit 071df6eb13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1049,6 +1049,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# Save config and origin of the pretrained weights if given in model
self.config = config
self.name_or_path = config.name_or_path
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self.serving.input_signature[0])
def get_config(self):
    """Return the model's configuration as a plain ``dict`` (Keras serialization hook)."""
    config_dict = self.config.to_dict()
    return config_dict
@ -1097,29 +1099,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"""
raise NotImplementedError
def save(
    self,
    filepath,
    overwrite=True,
    include_optimizer=True,
    save_format=None,
    signatures=None,
    options=None,
    save_traces=True,
):
    """Save the model, defaulting the serving signature when none is supplied.

    Thin wrapper around ``tf.keras.Model.save``: if the caller passed no explicit
    ``signatures`` and the model exposes a ``serving`` function, that function is
    exported as the saved signature. All other arguments are forwarded unchanged.
    """
    # Fall back to the model's own serving function as the export signature.
    effective_signatures = signatures
    if effective_signatures is None and hasattr(self, "serving"):
        effective_signatures = self.serving
    forwarded = dict(
        overwrite=overwrite,
        include_optimizer=include_optimizer,
        save_format=save_format,
        signatures=effective_signatures,
        options=options,
        save_traces=save_traces,
    )
    super().save(filepath, **forwarded)
def get_input_embeddings(self) -> tf.keras.layers.Layer:
"""
Returns the model's input embeddings layer.