mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Mamba: add generative tests (#31478)
This commit is contained in:
parent
7d683f7bae
commit
83259e406d
@ -1830,6 +1830,12 @@ class GenerationMixin:
|
||||
raise ValueError("assisted generate requires `use_cache=True`")
|
||||
if generation_config.cache_implementation == "static":
|
||||
raise ValueError("assisted generate is not supported with `static_cache`")
|
||||
if self._is_stateful:
|
||||
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
|
||||
# which is not possible with stateful models (they can't reset to a previous subset of generated text)
|
||||
raise ValueError(
|
||||
f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
# 11. Get the candidate generator, given the parameterization
|
||||
candidate_generator = self._get_candidate_generator(
|
||||
@ -1867,6 +1873,11 @@ class GenerationMixin:
|
||||
elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
|
||||
if not model_kwargs["use_cache"]:
|
||||
raise ValueError("Contrastive search requires `use_cache=True`")
|
||||
if self._is_stateful:
|
||||
# Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
|
||||
raise ValueError(
|
||||
f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
result = self._contrastive_search(
|
||||
input_ids,
|
||||
|
@ -1281,6 +1281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
is_parallelizable = False
|
||||
supports_gradient_checkpointing = False
|
||||
_is_stateful = False
|
||||
|
||||
# Flash Attention 2 support
|
||||
_supports_flash_attn_2 = False
|
||||
|
@ -1266,6 +1266,7 @@ class JambaPreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
|
||||
_is_stateful = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
|
@ -354,6 +354,7 @@ class MambaPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "backbone"
|
||||
_no_split_modules = ["MambaBlock"]
|
||||
supports_gradient_checkpointing = True
|
||||
_is_stateful = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
|
@ -394,6 +394,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["RwkvBlock"]
|
||||
_keep_in_fp32_modules = ["time_decay", "time_first"]
|
||||
supports_gradient_checkpointing = True
|
||||
_is_stateful = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
|
@ -102,7 +102,11 @@ class GenerationTesterMixin:
|
||||
if isinstance(config.eos_token_id, int):
|
||||
config.eos_token_id = [config.eos_token_id]
|
||||
config.pad_token_id = config.eos_token_id[0]
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
|
||||
if self.has_attentions:
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
# It is important set set the eos_token_id to None to ensure that no sequences
|
||||
# shorter than `max_length` can be generated
|
||||
@ -437,7 +441,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@ -471,7 +475,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@ -529,7 +533,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@ -595,7 +599,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -642,7 +646,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@ -733,7 +737,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@ -834,7 +838,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -952,7 +956,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@ -973,6 +977,9 @@ class GenerationTesterMixin:
|
||||
|
||||
def test_contrastive_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
@ -997,6 +1004,9 @@ class GenerationTesterMixin:
|
||||
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
@ -1017,7 +1027,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@ -1030,9 +1040,12 @@ class GenerationTesterMixin:
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
||||
self.skipTest("TODO: fix me")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
@ -1069,6 +1082,8 @@ class GenerationTesterMixin:
|
||||
def test_beam_search_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("May fix in the future: need custom cache handling")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
@ -1115,6 +1130,8 @@ class GenerationTesterMixin:
|
||||
# - assisted_decoding does not support `batch_size > 1`
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
@ -1156,7 +1173,7 @@ class GenerationTesterMixin:
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
@ -1184,6 +1201,8 @@ class GenerationTesterMixin:
|
||||
# This test is mostly a copy of test_assisted_decoding_matches_greedy_search
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
@ -1225,7 +1244,7 @@ class GenerationTesterMixin:
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
|
||||
@ -1244,6 +1263,8 @@ class GenerationTesterMixin:
|
||||
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
|
||||
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
@ -1289,7 +1310,7 @@ class GenerationTesterMixin:
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
@ -1326,7 +1347,7 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=1,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
remove_invalid_values=True,
|
||||
**{name: mask},
|
||||
@ -1344,6 +1365,10 @@ class GenerationTesterMixin:
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
self.skipTest(reason="No generative architecture available for this model.")
|
||||
|
||||
# - The model must support padding
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="This model doesn't support padding.")
|
||||
|
||||
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||
decoder_only_classes = []
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@ -1704,30 +1729,31 @@ class GenerationTesterMixin:
|
||||
self._check_logits(num_sequences_in_output, output.logits, config=config)
|
||||
|
||||
# Attentions
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
# if use_cache first input is equal to no use_cache, so skip here
|
||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
attentions=attentions,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
if self.has_attentions:
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
# if use_cache first input is equal to no use_cache, so skip here
|
||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
attentions=attentions,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
if config.is_encoder_decoder:
|
||||
@ -1763,7 +1789,7 @@ class GenerationTesterMixin:
|
||||
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
|
||||
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
|
||||
# complete
|
||||
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba")
|
||||
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
|
||||
has_standard_cache = not any(
|
||||
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
||||
)
|
||||
|
@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# They should result in very similar logits
|
||||
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
|
||||
|
||||
@unittest.skip("Jamba has its own special cache type") # FIXME: @gante
|
||||
def test_assisted_decoding_matches_greedy_search_0_random(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
|
@ -250,6 +250,8 @@ class MambaModelTester:
|
||||
@require_torch
|
||||
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
|
||||
has_attentions = False # Mamba does not support attentions
|
||||
fx_compatible = False # FIXME let's try to support this @ArthurZucker
|
||||
test_torchscript = False # FIXME let's try to support this @ArthurZucker
|
||||
test_missing_keys = False
|
||||
@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip("No attention in mamba")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# check if it's a ones like
|
||||
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
|
||||
|
||||
@unittest.skip("Mamba does not use attention")
|
||||
def test_attention_outputs(self):
|
||||
r"""
|
||||
Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
|
||||
it has a shape `batch_size, seq_len, hidden_size`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = MambaModel.from_pretrained("hf-internal-testing/mamba-130m")
|
||||
|
Loading…
Reference in New Issue
Block a user