mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653)
This commit is contained in:
parent
0ea53822f8
commit
a541d97477
@ -15,9 +15,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -160,6 +161,24 @@ class FlaxGenerationMixin:
|
||||
"""
|
||||
return logits
|
||||
|
||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
||||
unused_model_args = []
|
||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args:
|
||||
model_args |= set(inspect.signature(self.__call__).parameters)
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
unused_model_args.append(key)
|
||||
|
||||
if unused_model_args:
|
||||
raise ValueError(
|
||||
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
|
||||
" generate arguments will also show up in this list)"
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: jnp.ndarray,
|
||||
@ -262,6 +281,9 @@ class FlaxGenerationMixin:
|
||||
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
```"""
|
||||
# Validate model kwargs
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# set init values
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -26,6 +27,7 @@ if is_flax_available():
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin:
|
||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
|
||||
@require_flax
class FlaxGenerationIntegrationTests(unittest.TestCase):
    """Integration checks for argument validation in Flax `generate`."""

    def test_validate_generation_inputs(self):
        """`generate` must raise a `ValueError` naming any keyword argument it cannot use."""
        model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")

        prompt = "Hello world"
        encoded = tokenizer(prompt, return_tensors="np")
        input_ids = encoded.input_ids

        # A misspelled argument (the real one is `do_sample`) is reported by name.
        with self.assertRaisesRegex(ValueError, "do_samples"):
            model.generate(input_ids, do_samples=True)

        # A keyword that no part of the model consumes is rejected as well.
        unknown_kwargs = {"foo": "bar"}
        with self.assertRaisesRegex(ValueError, "foo"):
            model.generate(input_ids, **unknown_kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user