mirror of https://github.com/huggingface/transformers.git
parent 808d6c50f8
commit 63ca6d9771
@@ -397,6 +397,8 @@ class FlaxGenerationMixin:
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
        else:  # by default let's always generate 10 new tokens
            generation_config.max_length = generation_config.max_length + input_ids_seq_length

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            raise ValueError(
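The hunk above resolves `generation_config.max_length` relative to the prompt: an explicit `max_new_tokens` budget is added to `input_ids_seq_length`, and otherwise the configured `max_length` itself is treated as a budget of new tokens. A minimal sketch of that arithmetic, using a standalone `GenerationConfig`; the prompt length of 8 and the budget of 16 are illustrative assumptions:

    from transformers import GenerationConfig

    input_ids_seq_length = 8  # length of a hypothetical prompt
    generation_config = GenerationConfig(max_new_tokens=16)

    if generation_config.max_new_tokens is not None:
        # explicit new-token budget: total length = prompt length + new tokens
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
    else:
        # no budget given: reinterpret max_length as a count of new tokens
        generation_config.max_length = generation_config.max_length + input_ids_seq_length

    assert generation_config.max_length == 24  # 16 new tokens after an 8-token prompt

Either way, `max_length` is an absolute total length by the time the `min_length` check runs, so the comparison guarding the `ValueError` stays well defined.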
@@ -101,6 +101,10 @@ class FlaxGenerationTesterMixin:
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)

        # Generate max 5 tokens only otherwise seems to be numerical error accumulation
        pt_model.generation_config.max_length = 5
        flax_model.generation_config.max_length = 5

        flax_generation_outputs = flax_model.generate(input_ids).sequences
        pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))
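The test hunk above copies the Flax parameters into the PyTorch model with `load_flax_weights_in_pytorch_model`, caps both models at 5 tokens, and expects identical greedy outputs. A self-contained sketch of the same check; the tiny GPT-2 checkpoint name is an assumption, not part of the diff:

    import numpy as np
    import torch
    from transformers import FlaxGPT2LMHeadModel, GPT2LMHeadModel
    from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model

    model_name = "hf-internal-testing/tiny-random-gpt2"  # assumed tiny checkpoint
    flax_model = FlaxGPT2LMHeadModel.from_pretrained(model_name, from_pt=True)
    pt_model = GPT2LMHeadModel.from_pretrained(model_name).eval()
    # mirror the test: load the Flax parameters into the PyTorch model
    pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)

    input_ids = np.array([[0, 1, 2]], dtype=np.int64)

    # cap generation at 5 tokens, as the test does, to bound numerical drift
    pt_model.generation_config.max_length = 5
    flax_model.generation_config.max_length = 5

    # both frameworks decode greedily by default, so the sequences should match
    flax_sequences = flax_model.generate(input_ids).sequences
    pt_sequences = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))
    assert np.array_equal(np.asarray(flax_sequences), pt_sequences.numpy())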
@@ -3002,7 +3002,7 @@ class ModelTesterMixin:

    def test_inputs_embeds_matches_input_ids_with_generate(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        for model_class in self.all_model_classes:
+        for model_class in self.all_generative_model_classes:
            if model_class.__name__ not in [
                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
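The one-line change above narrows `test_inputs_embeds_matches_input_ids_with_generate` from all model classes to the generative ones, which is what the test needs: it calls `generate()` twice, once with token ids and once with the embeddings looked up from those ids. A sketch of the property being tested, assuming a hypothetical tiny causal-LM checkpoint:

    import torch
    from transformers import GPT2LMHeadModel

    # assumed tiny checkpoint; any model in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES would do
    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").eval()
    input_ids = torch.tensor([[0, 1, 2, 3]])

    # embeddings the model would compute internally from the same ids
    inputs_embeds = model.get_input_embeddings()(input_ids)

    out_ids = model.generate(input_ids, max_new_tokens=5, do_sample=False)
    out_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5, do_sample=False)

    # when generating from inputs_embeds the prompt ids are not echoed back,
    # so compare only the newly generated tokens
    assert torch.equal(out_ids[:, input_ids.shape[1]:], out_embeds)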