Fix Longformer and LED (#9942)

* Fix Longformer and LED

* Add a test for graph execution with inputs_embeds

* Apply style
Julien Plu 2021-02-03 12:26:32 +01:00 committed by GitHub
parent d55e10beab
commit 3f77c26d74
3 changed files with 42 additions and 12 deletions
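
Context for the hunks below: when one of these models is traced for graph execution (for example under tf.function) with inputs_embeds, padding_len is a symbolic tensor rather than a Python integer, so a Python-level `if padding_len > 0:` cannot generally be resolved at trace time. The fix routes that branch through tf.math.greater and tf.cond instead. A minimal sketch of the pattern, in toy code that is not taken from the commit (pad_to_multiple is a made-up helper):

import tensorflow as tf

@tf.function
def pad_to_multiple(x, multiple):
    # Inside the traced function, seq_len and padding_len are symbolic tensors.
    seq_len = tf.shape(x)[1]
    padding_len = (multiple - seq_len % multiple) % multiple
    paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
    # Branch on a tensor predicate with tf.cond instead of a Python `if`.
    return tf.cond(
        tf.math.greater(padding_len, 0),
        lambda: tf.pad(x, paddings, constant_values=0),
        lambda: x,
    )

print(pad_to_multiple(tf.ones((1, 5), dtype=tf.int32), 4).shape)  # (1, 8)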


@@ -1665,7 +1665,6 @@ class TFLEDEncoder(tf.keras.layers.Layer):
     def compute_hidden_states(self, hidden_states, padding_len):
         return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
-    @tf.function
     def _pad_to_window_size(
         self,
         input_ids,
@@ -1685,26 +1684,28 @@ class TFLEDEncoder(tf.keras.layers.Layer):
         batch_size, seq_len = input_shape[:2]
         padding_len = (attention_window - seq_len % attention_window) % attention_window
-        if padding_len > 0:
+        if tf.math.greater(padding_len, 0):
             logger.info(
                 "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
                     seq_len, seq_len + padding_len, attention_window
                 )
             )
-            paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
+        paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
-            if input_ids is not None:
-                input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
+        if input_ids is not None:
+            input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
-            if inputs_embeds is not None:
+        if inputs_embeds is not None:
+            def pad_embeddings():
                 input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
                 inputs_embeds_padding = self.embed_tokens(input_ids_padding)
-                inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
+                return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
-            attention_mask = tf.pad(
-                attention_mask, paddings, constant_values=False
-            )  # no attention on the padding tokens
+            inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
+        attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
         return (
             padding_len,
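
Note how the reworked LED code separates the two kinds of conditions: `if input_ids is not None` and `if inputs_embeds is not None` stay as Python checks because they are resolved at trace time, while everything that depends on the tensor-valued padding_len is moved out of the old `if padding_len > 0:` block and handled with tf.pad (a zero-length pad is a no-op) or tf.cond. The Longformer hunks below make the same switch to a tensor predicate; that file already wrapped its embedding padding in tf.cond.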


@@ -1836,7 +1836,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         batch_size, seq_len = input_shape[:2]
         padding_len = (attention_window - seq_len % attention_window) % attention_window
-        if padding_len > 0:
+        if tf.math.greater(padding_len, 0):
             logger.info(
                 "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
                     seq_len, seq_len + padding_len, attention_window
@@ -1859,7 +1859,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
                 inputs_embeds_padding = self.embeddings(input_ids_padding)
                 return tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
-            inputs_embeds = tf.cond(padding_len > 0, pad_embeddings, lambda: inputs_embeds)
+            inputs_embeds = tf.cond(tf.math.greater(padding_len, 0), pad_embeddings, lambda: inputs_embeds)
         attention_mask = tf.pad(attention_mask, paddings, constant_values=False)  # no attention on the padding tokens
         token_type_ids = tf.pad(token_type_ids, paddings, constant_values=0)  # pad with token_type_id = 0
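
A hedged usage sketch of what the fix enables, mirroring the new common test added below for one concrete model; the tiny config values and token ids are illustrative only, not taken from the commit:

import tensorflow as tf
from transformers import LongformerConfig, TFLongformerModel

# Illustrative tiny config; any valid Longformer config follows the same code path.
config = LongformerConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    max_position_embeddings=512,
    attention_window=4,
)
model = TFLongformerModel(config)

# Sequence length 6 is not a multiple of attention_window, so the padding path runs.
input_ids = tf.constant([[2, 5, 8, 9, 3, 7]], dtype=tf.int32)
inputs = {"inputs_embeds": model.get_input_embeddings()(input_ids)}

@tf.function
def run_in_graph_mode():
    return model(inputs)

outputs = run_in_graph_mode()  # graph execution with inputs_embeds, as exercised by the new test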


@@ -884,6 +884,35 @@ class TFModelTesterMixin:
             model(inputs)
+    def test_graph_mode_with_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+            if not self.is_encoder_decoder:
+                inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
+            else:
+                inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
+            @tf.function
+            def run_in_graph_mode():
+                return model(inputs)
+            outputs = run_in_graph_mode()
+            self.assertIsNotNone(outputs)
     def test_numpy_arrays_inputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
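
Since TFModelTesterMixin is shared by the per-model TF test classes, the new test_graph_mode_with_inputs_embeds runs for every TF model, including the two fixed here. Assuming the test layout of this transformers version, it can be run in isolation with, for example:

python -m pytest tests/test_modeling_tf_longformer.py tests/test_modeling_tf_led.py -k test_graph_mode_with_inputs_embeds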