Reduce memory usage in TF building (#24046)

* Make the default dummies (2, 2) instead of (3, 3)

* Fix for Funnel

* Actually fix Funnel
This commit is contained in:
Matt 2023-06-06 18:29:54 +01:00 committed by GitHub
parent 072188d638
commit 7203ea6797
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 3 deletions

View File

@ -1116,8 +1116,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
dummies = {}
sig = self._prune_signature(self.input_signature)
for key, spec in sig.items():
# 3 is the most correct arbitrary size. I will not be taking questions
dummies[key] = tf.ones(shape=[dim if dim is not None else 3 for dim in spec.shape], dtype=spec.dtype)
# 2 is the most correct arbitrary size. I will not be taking questions
dummies[key] = tf.ones(shape=[dim if dim is not None else 2 for dim in spec.shape], dtype=spec.dtype)
if key == "token_type_ids":
# Some models have token_type_ids but with a vocab_size of 1
dummies[key] = tf.zeros_like(dummies[key])
@ -1125,7 +1125,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if "encoder_hidden_states" not in dummies:
if self.main_input_name == "input_ids":
dummies["encoder_hidden_states"] = tf.ones(
shape=(3, 3, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
shape=(2, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
)
else:
raise NotImplementedError(

View File

@ -242,6 +242,7 @@ class TFFunnelAttentionStructure:
# rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
rel_pos = rel_pos + zero_offset
tf.debugging.assert_less(rel_pos, tf.shape(pos_embed)[0])
position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)
position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])
@ -974,6 +975,11 @@ class TFFunnelPreTrainedModel(TFPreTrainedModel):
config_class = FunnelConfig
base_model_prefix = "funnel"
@property
def dummy_inputs(self):
    """Return the dummy inputs for Funnel models.

    Funnel misbehaves with very small inputs, so this override supplies
    slightly bigger dummies: a single all-ones int32 `input_ids` tensor of
    shape (3, 3).
    """
    input_ids = tf.ones(shape=(3, 3), dtype=tf.int32)
    return {"input_ids": input_ids}
@dataclass
class TFFunnelForPreTrainingOutput(ModelOutput):
@ -1424,6 +1430,10 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
self.funnel = TFFunnelBaseLayer(config, name="funnel")
self.classifier = TFFunnelClassificationHead(config, 1, name="classifier")
@property
def dummy_inputs(self):
    """Return the dummy inputs for the multiple-choice model.

    The result holds a single all-ones int32 `input_ids` tensor of shape
    (3, 3, 4).
    """
    # NOTE(review): the axes are presumably
    # (batch_size, num_choices, sequence_length) -- confirm against the
    # forward signature.
    input_ids = tf.ones(shape=(3, 3, 4), dtype=tf.int32)
    return {"input_ids": input_ids}
@unpack_inputs
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(