mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Call _set_save_spec() when creating TF models (#19321)
* Add a build_from_serving_sig_and_dummies method and replace all calls like model(model.dummy_inputs) with it. * make fixup * Remove the overridden save() as this is no longer necessary * Also call _set_save_spec(), the last missing piece * Ensure we set the save spec when loading from config too * Turn this whole thing into a one-line PR * Turn this whole thing into a one-line PR * Turn this whole thing into a one-line PR Co-authored-by: Your Name <you@example.com>
This commit is contained in:
parent
c875a96eb1
commit
071df6eb13
@ -1049,6 +1049,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# Save config and origin of the pretrained weights if given in model
|
||||
self.config = config
|
||||
self.name_or_path = config.name_or_path
|
||||
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
|
||||
self._set_save_spec(self.serving.input_signature[0])
|
||||
|
||||
def get_config(self):
|
||||
return self.config.to_dict()
|
||||
@ -1097,29 +1099,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(
|
||||
self,
|
||||
filepath,
|
||||
overwrite=True,
|
||||
include_optimizer=True,
|
||||
save_format=None,
|
||||
signatures=None,
|
||||
options=None,
|
||||
save_traces=True,
|
||||
):
|
||||
# Very simple wrapper that ensures we set the correct serving signature when saving
|
||||
if signatures is None and hasattr(self, "serving"):
|
||||
signatures = self.serving
|
||||
super().save(
|
||||
filepath,
|
||||
overwrite=overwrite,
|
||||
include_optimizer=include_optimizer,
|
||||
save_format=save_format,
|
||||
signatures=signatures,
|
||||
options=options,
|
||||
save_traces=save_traces,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self) -> tf.keras.layers.Layer:
|
||||
"""
|
||||
Returns the model's input embeddings layer.
|
||||
|
Loading…
Reference in New Issue
Block a user