mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: filter encoder inputs when its signature does not accept wildcards (#21603)
This commit is contained in:
parent
41fa672df1
commit
13e03e619d
@ -1078,18 +1078,24 @@ class TFGenerationMixin:
|
||||
def _prepare_encoder_decoder_kwargs_for_generation(
|
||||
self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
# get encoder and store encoder outputs
|
||||
# 1. get encoder and store encoder outputs
|
||||
encoder = self.get_encoder()
|
||||
|
||||
# prepare encoder args and encoder kwargs from model kwargs
|
||||
# 2. prepare encoder args and encoder kwargs from model kwargs
|
||||
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
||||
encoder_kwargs = {
|
||||
argument: value
|
||||
for argument, value in model_kwargs.items()
|
||||
if not any(argument.startswith(p) for p in irrelevant_prefix)
|
||||
}
|
||||
encoder_signature = set(inspect.signature(encoder.call).parameters)
|
||||
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
|
||||
if not encoder_accepts_wildcard:
|
||||
encoder_kwargs = {
|
||||
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
|
||||
}
|
||||
|
||||
# vision models don't use `attention_mask`.
|
||||
# 3. vision models don't use `attention_mask`.
|
||||
encoder_kwargs["return_dict"] = True
|
||||
encoder_kwargs[model_input_name] = inputs_tensor
|
||||
if model_input_name != self.main_input_name: # in Keras, the first input must always be passed
|
||||
|
@ -609,13 +609,19 @@ class GenerationMixin:
|
||||
# 1. get encoder
|
||||
encoder = self.get_encoder()
|
||||
|
||||
# 2. prepare encoder args and encoder kwargs from model kwargs
|
||||
# 2. Prepare encoder args and encoder kwargs from model kwargs.
|
||||
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
||||
encoder_kwargs = {
|
||||
argument: value
|
||||
for argument, value in model_kwargs.items()
|
||||
if not any(argument.startswith(p) for p in irrelevant_prefix)
|
||||
}
|
||||
encoder_signature = set(inspect.signature(encoder.forward).parameters)
|
||||
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
|
||||
if not encoder_accepts_wildcard:
|
||||
encoder_kwargs = {
|
||||
argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
|
||||
}
|
||||
|
||||
# 3. make sure that encoder returns `ModelOutput`
|
||||
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
|
||||
|
@ -16,6 +16,8 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
@ -32,6 +34,7 @@ if is_tf_available():
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSpeechSeq2Seq,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFBartForConditionalGeneration,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
tf_top_k_top_p_filtering,
|
||||
@ -264,3 +267,38 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
tf.random.set_seed(0)
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_model_kwarg_encoder_signature_filtering(self):
|
||||
# Has PT equivalent: ample use of framework-specific code
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
article = """Hugging Face is a technology company based in New York and Paris."""
|
||||
input_ids = bart_tokenizer(article, return_tensors="tf").input_ids
|
||||
bart_model = TFBartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
output = bart_model.generate(input_ids).numpy()
|
||||
|
||||
# Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
|
||||
# argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
|
||||
# the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
|
||||
# saves the day.
|
||||
class FakeBart(TFBartForConditionalGeneration):
|
||||
def call(self, input_ids, foo=None, **kwargs):
|
||||
return super().call(input_ids, **kwargs)
|
||||
|
||||
bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
fake_output = bart_model.generate(input_ids, foo="bar").numpy()
|
||||
self.assertTrue(np.array_equal(output, fake_output))
|
||||
|
||||
# Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail
|
||||
# because it doesn't do signature filtering.
|
||||
class FakeEncoder(bart_model.model.encoder.__class__):
|
||||
def call(self, input_ids, **kwargs):
|
||||
return super().call(input_ids, **kwargs)
|
||||
|
||||
fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared)
|
||||
bart_model.model.encoder = fake_encoder
|
||||
|
||||
# Normal generation still works (the output will be different because the encoder weights are different)
|
||||
fake_output = bart_model.generate(input_ids).numpy()
|
||||
with self.assertRaises(ValueError):
|
||||
# FakeEncoder.call() accepts **kwargs -> no filtering -> value error due to unexpected input "foo"
|
||||
bart_model.generate(input_ids, foo="bar")
|
||||
|
@ -17,6 +17,8 @@
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_torch_available, pipeline
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
@ -2439,30 +2441,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
max_score_diff = (output_sequences_batched.scores[0][1] - output_sequences.scores[0][0]).abs().max()
|
||||
self.assertTrue(max_score_diff < 1e-5)
|
||||
|
||||
def test_generate_from_input_embeds_decoder_only(self):
|
||||
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
|
||||
# Note: the model must support generation from input embeddings
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
|
||||
text = "Hello world"
|
||||
input_ids = tokenizer.encode(text, return_tensors="pt")
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_from_ids = model.generate(input_ids)
|
||||
|
||||
# Same thing, but from input embeddings
|
||||
inputs_embeds = model.transformer.wte(input_ids)
|
||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
||||
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
# But if we pass different inputs_embeds, we should get different outputs
|
||||
torch.manual_seed(0)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
||||
|
||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
||||
# Has TF equivalent: this test relies on random sampling
|
||||
generation_kwargs = {
|
||||
@ -2490,6 +2468,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||
# PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;))
|
||||
# Note: the model must support generation from input embeddings
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
@ -2523,3 +2502,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
|
||||
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
||||
)
|
||||
|
||||
def test_model_kwarg_encoder_signature_filtering(self):
|
||||
# Has TF equivalent: ample use of framework-specific code
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
article = """Hugging Face is a technology company based in New York and Paris."""
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
|
||||
torch_device
|
||||
)
|
||||
output = bart_model.generate(input_ids).cpu().numpy()
|
||||
|
||||
# Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
|
||||
# argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
|
||||
# the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
|
||||
# saves the day.
|
||||
class FakeBart(BartForConditionalGeneration):
|
||||
def forward(self, input_ids, foo=None, **kwargs):
|
||||
return super().forward(input_ids, **kwargs)
|
||||
|
||||
bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
|
||||
fake_output = bart_model.generate(input_ids, foo="bar").cpu().numpy()
|
||||
self.assertTrue(np.array_equal(output, fake_output))
|
||||
|
||||
# Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail
|
||||
# because it doesn't do signature filtering.
|
||||
class FakeEncoder(bart_model.model.encoder.__class__):
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return super().forward(input_ids, **kwargs)
|
||||
|
||||
fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared).to(torch_device)
|
||||
bart_model.model.encoder = fake_encoder
|
||||
|
||||
# Normal generation still works (the output will be different because the encoder weights are different)
|
||||
fake_output = bart_model.generate(input_ids).cpu().numpy()
|
||||
with self.assertRaises(TypeError):
|
||||
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
|
||||
bart_model.generate(input_ids, foo="bar")
|
||||
|
Loading…
Reference in New Issue
Block a user