Mirror of https://github.com/huggingface/transformers.git
Generate: TF .generate() can now be exported with dynamic length (#21474)
commit 2edf9a857b
parent e69f9715eb
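For readers skimming the diff: the practical upshot is that `.generate()` can now be wrapped in a `tf.function` whose `input_signature` leaves the sequence dimension as `None`, so a single exported SavedModel serves prompts of any length. The sketch below is illustrative only and mirrors the new test added further down; the tiny test checkpoint, the output directory name, the fixed batch size of 1, and `max_new_tokens=2` are placeholder choices, not part of the commit itself.

import tensorflow as tf
from transformers import TFAutoModelForCausalLM

# Illustrative sketch: export .generate() with a dynamic sequence length (batch fixed at 1).
model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

class ExportableGenerate(tf.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(
        input_signature=(
            # `None` in the sequence position is the dynamic length this commit enables.
            tf.TensorSpec((1, None), tf.int32, name="input_ids"),
            tf.TensorSpec((1, None), tf.int32, name="attention_mask"),
        ),
        jit_compile=True,
    )
    def serving(self, input_ids, attention_mask):
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=2,
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}

exportable = ExportableGenerate(model)
tf.saved_model.save(exportable, "generate_saved_model", signatures={"serving_default": exportable.serving})

After loading the SavedModel, the `serving_default` signature accepts input_ids of any length with batch size 1, which is exactly what the new test at the bottom of this diff verifies.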
@@ -849,7 +849,7 @@ class TFGenerationMixin:
             input_ids = inputs_tensor

         # 7. Prepare `max_length` depending on other stopping criteria.
-        input_ids_seq_length = input_ids.shape[-1]
+        input_ids_seq_length = shape_list(input_ids)[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
@@ -869,18 +869,23 @@ class TFGenerationMixin:
                 UserWarning,
             )

-        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
-            raise ValueError(
-                f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
-                f" the maximum length ({generation_config.max_length})"
-            )
-        if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-            logger.warning(
-                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
-                " increasing`max_new_tokens`."
-            )
+        # If the input length is a tensor (i.e. dynamic length), skip length checks
+        if not isinstance(input_ids_seq_length, tf.Tensor):
+            if (
+                generation_config.min_length is not None
+                and generation_config.min_length > generation_config.max_length
+            ):
+                raise ValueError(
+                    f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger"
+                    f" than the maximum length ({generation_config.max_length})"
+                )
+            if input_ids_seq_length >= generation_config.max_length:
+                input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+                logger.warning(
+                    f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                    f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                    " increasing`max_new_tokens`."
+                )

         # 8. determine generation mode
         is_contrastive_search_gen_mode = (
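Context for the `shape_list` swap in the hunk above: helpers of this kind return plain Python ints for dimensions that are statically known and scalar `tf.Tensor`s for dimensions only known at run time, which is what the new `isinstance(input_ids_seq_length, tf.Tensor)` guard keys on. A minimal standalone sketch follows; it re-implements the behavior locally rather than importing the library helper, so nothing here should be read as the exact `transformers` API.

import tensorflow as tf

def shape_list_like(tensor):
    # Static dims come back as Python ints, dynamic dims as scalar tf.Tensors.
    dynamic = tf.shape(tensor)
    static = tensor.shape.as_list()
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

# Fully static shape: the length is a plain int, so the length checks still run.
print(isinstance(shape_list_like(tf.ones((1, 4), dtype=tf.int32))[-1], tf.Tensor))  # False

@tf.function(input_signature=(tf.TensorSpec((1, None), tf.int32),))
def traced(input_ids):
    seq_len = shape_list_like(input_ids)[-1]
    # Traced with a dynamic last dimension: seq_len is a tf.Tensor here,
    # so the Python-level min/max length checks above are skipped.
    assert isinstance(seq_len, tf.Tensor)
    return input_ids

traced(tf.constant([[1, 2, 3]], dtype=tf.int32))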
@@ -182,7 +182,7 @@ class TFAttention(tf.keras.layers.Layer):
         key = self.split_heads(key)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=0)
+            past_key, past_value = tf.unstack(layer_past, axis=0, num=2)
             key = tf.concat([past_key, key], axis=-2)
             value = tf.concat([past_value, value], axis=-2)

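On the `tf.unstack` change just above: without an explicit `num`, `tf.unstack` infers the number of output tensors from the static size of the unstacked axis, and tracing fails when that size is unknown; passing `num=2` (the stacked key/value pair) keeps the op usable under dynamic shapes. A small illustration under simplified, made-up shapes (the real GPT-2 past tensor has more dimensions):

import tensorflow as tf

@tf.function(input_signature=(tf.TensorSpec((None, None, 4), tf.float32),))
def split_without_num(layer_past):
    # Axis 0 has no static size in this signature, so tf.unstack cannot
    # infer how many tensors to return and tracing raises a ValueError.
    past_key, past_value = tf.unstack(layer_past, axis=0)
    return past_key, past_value

@tf.function(input_signature=(tf.TensorSpec((None, None, 4), tf.float32),))
def split_with_num(layer_past):
    # num=2 makes the output count explicit, so the dynamic axis is fine.
    past_key, past_value = tf.unstack(layer_past, axis=0, num=2)
    return past_key, past_value

try:
    split_without_num.get_concrete_function()
except ValueError as err:
    print("without num:", err)

past_key, past_value = split_with_num(tf.ones((2, 5, 4)))
print(past_key.shape, past_value.shape)  # (5, 4) (5, 4)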
@@ -144,9 +144,10 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
         }

     @slow
-    def test_generate_tf_function_export(self):
+    def test_generate_tf_function_export_fixed_input_length(self):
         test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        max_length = 2
+        input_length = 2
+        max_new_tokens = 2

         class DummyModel(tf.Module):
             def __init__(self, model):
@@ -155,8 +156,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests

             @tf.function(
                 input_signature=(
-                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
-                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
+                    tf.TensorSpec((None, input_length), tf.int32, name="input_ids"),
+                    tf.TensorSpec((None, input_length), tf.int32, name="attention_mask"),
                 ),
                 jit_compile=True,
             )
@@ -164,7 +165,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
                 outputs = self.model.generate(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
-                    max_new_tokens=max_length,
+                    max_new_tokens=max_new_tokens,
                     return_dict_in_generate=True,
                 )
                 return {"sequences": outputs["sequences"]}
@@ -181,5 +182,47 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
                     "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
                 }
                 tf_func_outputs = serving_func(**inputs)["sequences"]
-                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
+                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
                 tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
+
+    @slow
+    def test_generate_tf_function_export_fixed_batch_size(self):
+        test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        batch_size = 1
+        max_new_tokens = 2
+
+        class DummyModel(tf.Module):
+            def __init__(self, model):
+                super(DummyModel, self).__init__()
+                self.model = model
+
+            @tf.function(
+                input_signature=(
+                    tf.TensorSpec((batch_size, None), tf.int32, name="input_ids"),
+                    tf.TensorSpec((batch_size, None), tf.int32, name="attention_mask"),
+                ),
+                jit_compile=True,
+            )
+            def serving(self, input_ids, attention_mask):
+                outputs = self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=max_new_tokens,
+                    return_dict_in_generate=True,
+                )
+                return {"sequences": outputs["sequences"]}
+
+        dummy_input_ids = [[2], [102, 103]]
+        dummy_attention_masks = [[1], [1, 1]]
+        dummy_model = DummyModel(model=test_model)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
+            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
+            for input_row in range(len(dummy_input_ids)):
+                inputs = {
+                    "input_ids": tf.constant([dummy_input_ids[input_row]]),
+                    "attention_mask": tf.constant([dummy_attention_masks[input_row]]),
+                }
+                tf_func_outputs = serving_func(**inputs)["sequences"]
+                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
+                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)