Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Generate: fix TF XLA tests on models with max_position_embeddings or max_target_positions (#21389)
This commit is contained in:
parent 6342427353
commit 19d67bfecb
@@ -1865,6 +1865,17 @@ class TFModelTesterMixin:
             config.eos_token_id = None  # Generate until max length
             config.do_sample = False
 
+            # fix config for models with additional sequence-length limiting settings
+            for var_name in ["max_position_embeddings", "max_target_positions"]:
+                attr = getattr(config, var_name, None)
+                if attr is not None and attr < generate_kwargs["max_new_tokens"]:
+                    try:
+                        setattr(config, var_name, generate_kwargs["max_new_tokens"])
+                    except NotImplementedError:
+                        # xlnet will raise an exception when trying to set
+                        # max_position_embeddings.
+                        pass
+
             model = model_class(config)
 
             if model.supports_xla_generation:
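For reference, a minimal standalone sketch of the pattern this hunk adds: probe each length-limiting attribute with getattr, and raise it to at least max_new_tokens so generation is not cut short, tolerating configs whose setter raises. DummyConfig and relax_length_limits are hypothetical names for illustration only; the real test mutates the model's actual config object.

    # Hypothetical stand-in for a model config with a sequence-length limit.
    class DummyConfig:
        def __init__(self, max_position_embeddings):
            self.max_position_embeddings = max_position_embeddings

    def relax_length_limits(config, max_new_tokens):
        """Raise each length-limiting attribute to at least max_new_tokens."""
        for var_name in ["max_position_embeddings", "max_target_positions"]:
            attr = getattr(config, var_name, None)  # None if the model lacks this setting
            if attr is not None and attr < max_new_tokens:
                try:
                    setattr(config, var_name, max_new_tokens)
                except NotImplementedError:
                    # Per the original comment, xlnet's config raises
                    # NotImplementedError when max_position_embeddings is set.
                    pass

    config = DummyConfig(max_position_embeddings=8)
    relax_length_limits(config, max_new_tokens=16)
    print(config.max_position_embeddings)  # -> 16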