diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index f7ac3d1fed1..dd1719294e8 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1830,6 +1830,12 @@ class GenerationMixin:
                 raise ValueError("assisted generate requires `use_cache=True`")
             if generation_config.cache_implementation == "static":
                 raise ValueError("assisted generate is not supported with `static_cache`")
+            if self._is_stateful:
+                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
+                # which is not possible with stateful models (they can't reset to a previous subset of generated text)
+                raise ValueError(
+                    f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
+                )
 
             # 11. Get the candidate generator, given the parameterization
             candidate_generator = self._get_candidate_generator(
@@ -1867,6 +1873,11 @@ class GenerationMixin:
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
+            if self._is_stateful:
+                # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
+                raise ValueError(
+                    f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
+                )
 
             result = self._contrastive_search(
                 input_ids,
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1b0456eff92..f7b0db6d77f 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1281,6 +1281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
     is_parallelizable = False
     supports_gradient_checkpointing = False
+    _is_stateful = False
 
     # Flash Attention 2 support
     _supports_flash_attn_2 = False
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index bcb520e8211..f49f55f5779 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -1266,6 +1266,7 @@ class JambaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True  # Note: only supports HybridMambaAttentionDynamicCache
+    _is_stateful = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 82cbef3033d..04430ada87a 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -354,6 +354,7 @@ class MambaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "backbone"
     _no_split_modules = ["MambaBlock"]
     supports_gradient_checkpointing = True
+    _is_stateful = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py
index d9e4bfadf32..8568bd999eb 100644
--- a/src/transformers/models/rwkv/modeling_rwkv.py
+++ b/src/transformers/models/rwkv/modeling_rwkv.py
@@ -394,6 +394,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
     supports_gradient_checkpointing = True
+    _is_stateful = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 1981f5a6391..6215bc87edf 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -102,7 +102,11 @@ class GenerationTesterMixin:
         if isinstance(config.eos_token_id, int):
             config.eos_token_id = [config.eos_token_id]
         config.pad_token_id = config.eos_token_id[0]
-        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+
+        if self.has_attentions:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+        else:
+            attention_mask = None
 
         # It is important set set the eos_token_id to None to ensure that no sequences
         # shorter than `max_length` can be generated
@@ -437,7 +441,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
             )
 
@@ -471,7 +475,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
             )
 
@@ -529,7 +533,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
            )
 
@@ -595,7 +599,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
             )
             if model.config.is_encoder_decoder:
@@ -642,7 +646,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
             )
 
@@ -733,7 +737,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
             )
 
@@ -834,7 +838,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
             )
             if model.config.is_encoder_decoder:
@@ -952,7 +956,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
             )
 
@@ -973,6 +977,9 @@ class GenerationTesterMixin:
 
     def test_contrastive_generate(self):
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support contrastive search generation")
+
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
@@ -997,6 +1004,9 @@ class GenerationTesterMixin:
 
     def test_contrastive_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support contrastive search generation")
+
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
@@ -1017,7 +1027,7 @@ class GenerationTesterMixin:
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
             )
 
@@ -1030,9 +1040,12 @@ class GenerationTesterMixin:
     def test_contrastive_generate_low_memory(self):
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support contrastive search generation")
+
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
                 self.skipTest("Won't fix: old model with different cache format")
-            if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
                 self.skipTest("TODO: fix me")
 
             config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
@@ -1069,6 +1082,8 @@ class GenerationTesterMixin:
     def test_beam_search_low_memory(self):
         # Check that choosing 'low_memory' does not change the model output
        for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("May fix in the future: need custom cache handling")
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
@@ -1115,6 +1130,8 @@ class GenerationTesterMixin:
         # - assisted_decoding does not support `batch_size > 1`
 
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support assisted generation")
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
@@ -1156,7 +1173,7 @@ class GenerationTesterMixin:
                 "output_scores": True,
                 "output_logits": True,
                 "output_hidden_states": True,
-                "output_attentions": True,
+                "output_attentions": self.has_attentions,
                 "return_dict_in_generate": True,
             }
             output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1184,6 +1201,8 @@ class GenerationTesterMixin:
         # This test is mostly a copy of test_assisted_decoding_matches_greedy_search
 
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support assisted generation")
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
@@ -1225,7 +1244,7 @@ class GenerationTesterMixin:
                 "output_scores": True,
                 "output_logits": True,
                 "output_hidden_states": True,
-                "output_attentions": True,
+                "output_attentions": self.has_attentions,
                 "return_dict_in_generate": True,
             }
 
@@ -1244,6 +1263,8 @@ class GenerationTesterMixin:
         # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
         # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
         for model_class in self.all_generative_model_classes:
+            if model_class._is_stateful:
+                self.skipTest("Stateful models don't support assisted generation")
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 self.skipTest("Won't fix: old model with different cache format")
             if any(
@@ -1289,7 +1310,7 @@ class GenerationTesterMixin:
                 "output_scores": True,
                 "output_logits": True,
                 "output_hidden_states": True,
-                "output_attentions": True,
+                "output_attentions": self.has_attentions,
                 "return_dict_in_generate": True,
             }
             output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1326,7 +1347,7 @@ class GenerationTesterMixin:
                 input_ids,
                 attention_mask=attention_mask,
                 num_beams=1,
-                output_attentions=True,
+                output_attentions=self.has_attentions,
                 return_dict_in_generate=True,
                 remove_invalid_values=True,
                 **{name: mask},
@@ -1344,6 +1365,10 @@ class GenerationTesterMixin:
         if len(self.all_generative_model_classes) == 0:
             self.skipTest(reason="No generative architecture available for this model.")
 
+        # - The model must support padding
+        if not self.has_attentions:
+            self.skipTest(reason="This model doesn't support padding.")
+
         # - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
         decoder_only_classes = []
         for model_class in self.all_generative_model_classes:
@@ -1704,30 +1729,31 @@ class GenerationTesterMixin:
         self._check_logits(num_sequences_in_output, output.logits, config=config)
 
         # Attentions
-        if config.is_encoder_decoder:
-            # encoder
-            self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
-            # decoder
-            self._check_attentions_for_generate(
-                num_sequences_in_output,
-                output.decoder_attentions,
-                min_length=1,
-                max_length=output.sequences.shape[-1],
-                config=config,
-                use_cache=use_cache,
-            )
-        else:
-            # if use_cache first input is equal to no use_cache, so skip here
-            attentions = output.attentions if not use_cache else output.attentions[1:]
-            min_length = seq_length if not use_cache else seq_length + 1
-            self._check_attentions_for_generate(
-                num_sequences_in_output,
-                attentions=attentions,
-                min_length=min_length,
-                max_length=output.sequences.shape[-1],
-                config=config,
-                use_cache=use_cache,
-            )
+        if self.has_attentions:
+            if config.is_encoder_decoder:
+                # encoder
+                self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
+                # decoder
+                self._check_attentions_for_generate(
+                    num_sequences_in_output,
+                    output.decoder_attentions,
+                    min_length=1,
+                    max_length=output.sequences.shape[-1],
+                    config=config,
+                    use_cache=use_cache,
+                )
+            else:
+                # if use_cache first input is equal to no use_cache, so skip here
+                attentions = output.attentions if not use_cache else output.attentions[1:]
+                min_length = seq_length if not use_cache else seq_length + 1
+                self._check_attentions_for_generate(
+                    num_sequences_in_output,
+                    attentions=attentions,
+                    min_length=min_length,
+                    max_length=output.sequences.shape[-1],
+                    config=config,
+                    use_cache=use_cache,
+                )
 
         # Hidden States
         if config.is_encoder_decoder:
@@ -1763,7 +1789,7 @@ class GenerationTesterMixin:
         # 2. Some old models still return `output.past_key_values` even without `use_cache=True`
         # 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
         #    complete
-        models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba")
+        models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
         has_standard_cache = not any(
             model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
         )
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index 13208d54f1c..f69eb0d806b 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         # They should result in very similar logits
         self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
 
-    @unittest.skip("Jamba has its own special cache type")  # FIXME: @gante
-    def test_assisted_decoding_matches_greedy_search_0_random(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py
index 7aec7add111..1ddb8ad700b 100644
--- a/tests/models/mamba/test_modeling_mamba.py
+++ b/tests/models/mamba/test_modeling_mamba.py
@@ -250,6 +250,8 @@ class MambaModelTester:
 @require_torch
 class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
+    all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
+    has_attentions = False  # Mamba does not support attentions
     fx_compatible = False  # FIXME let's try to support this @ArthurZucker
     test_torchscript = False  # FIXME let's try to support this @ArthurZucker
     test_missing_keys = False
@@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_config(self):
         self.config_tester.run_common_tests()
 
-    @unittest.skip("No attention in mamba")
-    def test_retain_grad_hidden_states_attentions(self):
-        pass
-
     @require_torch_multi_gpu
     def test_multi_gpu_data_parallel_forward(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             # check if it's a ones like
             self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
 
-    @unittest.skip("Mamba does not use attention")
-    def test_attention_outputs(self):
-        r"""
-        Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
-        it has a shape `batch_size, seq_len, hidden_size`.
-        """
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         model = MambaModel.from_pretrained("hf-internal-testing/mamba-130m")
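Usage sketch (illustrative, not part of the patch): the snippet below shows how the `_is_stateful` guard added to `GenerationMixin.generate` is expected to surface to users. The checkpoint name "state-spaces/mamba-130m-hf" is only an example; any model class marked `_is_stateful = True` by this change (Mamba, RWKV, Jamba) should behave the same way. Decoding strategies that never roll the state back keep working, while contrastive search and assisted generation now fail fast with a clear ValueError.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Example checkpoint; substitute any stateful (recurrent-cache) model.
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hello", return_tensors="pt").input_ids

# Regular greedy/sampling decoding is unaffected by this change.
model.generate(input_ids, max_new_tokens=10)

# Contrastive search (penalty_alpha + top_k) needs to roll the cache back to an earlier
# position, which a fixed-size recurrent state cannot do, so it is now rejected up front.
try:
    model.generate(input_ids, max_new_tokens=10, penalty_alpha=0.6, top_k=4)
except ValueError as err:
    print(err)  # contrastive search is not supported with stateful models, such as MambaForCausalLM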