Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653)

This commit is contained in:
Joao Gante 2022-08-18 10:56:21 +01:00 committed by GitHub
parent 0ea53822f8
commit a541d97477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 1 deletions

View File

@ -15,9 +15,10 @@
# limitations under the License.
import inspect
import warnings
from functools import partial
from typing import Dict, Optional
from typing import Any, Dict, Optional
import numpy as np
@ -160,6 +161,24 @@ class FlaxGenerationMixin:
"""
return logits
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.__call__).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
" generate arguments will also show up in this list)"
)
def generate(
self,
input_ids: jnp.ndarray,
@ -262,6 +281,9 @@ class FlaxGenerationMixin:
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
# Validate model kwargs
self._validate_model_kwargs(model_kwargs.copy())
# set init values
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id

View File

@ -13,6 +13,7 @@
# limitations under the License.
import random
import unittest
import numpy as np
@ -26,6 +27,7 @@ if is_flax_available():
import jax.numpy as jnp
from jax import jit
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin:
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
@require_flax
class FlaxGenerationIntegrationTests(unittest.TestCase):
def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
encoder_input_str = "Hello world"
input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids
# typos are quickly detected (the correct argument is `do_sample`)
with self.assertRaisesRegex(ValueError, "do_samples"):
model.generate(input_ids, do_samples=True)
# arbitrary arguments that will not be used anywhere are also not accepted
with self.assertRaisesRegex(ValueError, "foo"):
fake_model_kwargs = {"foo": "bar"}
model.generate(input_ids, **fake_model_kwargs)