From 896833c183a168befcade879c1fca1889e20052d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 23 May 2025 17:11:40 +0200 Subject: [PATCH] Fix some tests (especially compile with fullgraph=True on Python<3.11) (#38319) * fix tests * better fix for python<3.11 * fixes * style --- src/transformers/cache_utils.py | 2 +- src/transformers/integrations/executorch.py | 92 +++++++------------ src/transformers/masking_utils.py | 6 +- tests/models/cohere/test_modeling_cohere.py | 6 +- tests/models/csm/test_modeling_csm.py | 1 - tests/models/mixtral/test_modeling_mixtral.py | 23 +---- tests/test_modeling_common.py | 1 + 7 files changed, 48 insertions(+), 83 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 269c093aa6c..286d03a5598 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2232,7 +2232,7 @@ class OffloadedHybridCache(HybridChunkedCache): def _prefetch_layer_in_context(self, layer_idx: int) -> None: """Performs the actual copy of the layer to device cache.""" - if len(self.key_cache) >= layer_idx: + if len(self.key_cache) > layer_idx: self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True) self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True) # The layer was not yet initialized diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index eb17dab55af..bd4b30a3d12 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -11,7 +11,6 @@ # specific language governing permissions and limitations under the License. import logging -from contextlib import contextmanager from typing import Callable, Optional import torch @@ -110,14 +109,13 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long) example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) - with patch_mask_interface(): - exported_program = torch.export.export( - self.model, - args=(example_input_ids, example_cache_position), - kwargs={}, - dynamic_shapes=dynamic_shapes, - strict=strict if strict is not None else True, - ) + exported_program = torch.export.export( + self.model, + args=(example_input_ids, example_cache_position), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) return exported_program @staticmethod @@ -456,24 +454,6 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module): return outputs.logits -@contextmanager -def patch_mask_interface(): - """ - Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail - with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles: - Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`. - Note that this seem to be an issue only for python<3.11. 
- """ - import transformers - - original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS - transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping - try: - yield - finally: - transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original - - def convert_and_export_with_cache( model: PreTrainedModel, example_input_ids: Optional[torch.Tensor] = None, @@ -515,14 +495,13 @@ def convert_and_export_with_cache( ) if is_torch_greater_or_equal("2.6.0"): - with patch_mask_interface(): - exported_program = torch.export.export( - TorchExportableModuleWithStaticCache(model), - args=(example_input_ids, example_cache_position), - kwargs={}, - dynamic_shapes=dynamic_shapes, - strict=strict if strict is not None else True, - ) + exported_program = torch.export.export( + TorchExportableModuleWithStaticCache(model), + args=(example_input_ids, example_cache_position), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) else: if dynamic_shapes is not None: logging.warning( @@ -534,14 +513,13 @@ def convert_and_export_with_cache( # # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. - with patch_mask_interface(): - exported_program = torch.export._trace._export( - TorchExportableModuleWithStaticCache(model), - args=(example_input_ids,), - kwargs={"cache_position": example_cache_position}, - pre_dispatch=False, - strict=True, - ) + exported_program = torch.export._trace._export( + TorchExportableModuleWithStaticCache(model), + args=(example_input_ids,), + kwargs={"cache_position": example_cache_position}, + pre_dispatch=False, + strict=True, + ) return exported_program @@ -634,10 +612,9 @@ class Seq2SeqLMExportableModule(torch.nn.Module): # Export the encoder with torch.no_grad(): - with patch_mask_interface(): - exported_encoder = torch.export.export( - wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True - ) + exported_encoder = torch.export.export( + wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True + ) return exported_encoder @@ -657,17 +634,16 @@ class Seq2SeqLMExportableModule(torch.nn.Module): # Export the decoder with torch.no_grad(): - with patch_mask_interface(): - exported_decoder = torch.export.export( - wrapped_decoder, - (decoder_input_ids, encoder_hidden_states, cache_position), - dynamic_shapes={ - "decoder_input_ids": None, - "encoder_hidden_states": {1: encoder_seq_len_dim}, - "cache_position": None, - }, - strict=True, - ) + exported_decoder = torch.export.export( + wrapped_decoder, + (decoder_input_ids, encoder_hidden_states, cache_position), + dynamic_shapes={ + "decoder_input_ids": None, + "encoder_hidden_states": {1: encoder_seq_len_dim}, + "cache_position": None, + }, + strict=True, + ) return exported_decoder diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 36538882af5..53a81e1daaf 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -623,7 +623,11 @@ def _preprocess_mask_arguments( return True, attention_mask, None, None # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask! 
- if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS: + # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise + # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11 + # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped + # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11 + if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping: return True, None, None, None # Move the mask to correct device, and potentially switch dtype for efficiency diff --git a/tests/models/cohere/test_modeling_cohere.py b/tests/models/cohere/test_modeling_cohere.py index bebafedc7df..ff7963ae7e0 100644 --- a/tests/models/cohere/test_modeling_cohere.py +++ b/tests/models/cohere/test_modeling_cohere.py @@ -232,8 +232,8 @@ class CohereIntegrationTest(unittest.TestCase): EXPECTED_LOGITS = torch.Tensor( [ - [[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]], - [[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]], + [[0.0000, 0.0285, 0.0322], [0.0000, 0.0011, 0.1105], [0.0000, -0.0018, -0.1019]], + [[0.0000, 0.1080, 0.0454], [0.0000, -0.1808, -0.1553], [0.0000, 0.0452, 0.0369]], ] ).to(device=torch_device, dtype=torch.float16) @@ -251,4 +251,4 @@ class CohereIntegrationTest(unittest.TestCase): output = model(**inputs) logits = output.logits - torch.testing.assert_close(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3) + torch.testing.assert_close(EXPECTED_LOGITS, logits[:, -3:, :3], rtol=1e-3, atol=1e-3) diff --git a/tests/models/csm/test_modeling_csm.py b/tests/models/csm/test_modeling_csm.py index be4ab6a0e2a..26442ef8458 100644 --- a/tests/models/csm/test_modeling_csm.py +++ b/tests/models/csm/test_modeling_csm.py @@ -150,7 +150,6 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u test_headmasking = False test_resize_embeddings = False test_resize_embeddings_untied = False - test_torch_exportable = True def setUp(self): self.model_tester = CsmModelTester(self) diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 2d7c95529be..8f4215c7205 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -402,24 +402,12 @@ class MixtralIntegrationTest(unittest.TestCase): # # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, # considering differences in hardware processing and potential deviations in generated text. 
- EXPECTED_LOGITS_LEFT = { - 7: torch.Tensor( - [[0.1904, 0.0500, 0.7187], [0.1933, 0.0515, 0.7187], [0.2001, 0.0559, 0.7148]], - ).to(torch_device), - 8: torch.Tensor([[0.1914, 0.0508, 0.7188], [0.1953, 0.0510, 0.7227], [0.1973, 0.0562, 0.7148]]).to( - torch_device - ), - 9: torch.Tensor([[0.1904, 0.0513, 0.7227], [0.1943, 0.0518, 0.7227], [0.1982, 0.0557, 0.7148]]).to( - torch_device - ), - } - EXPECTED_LOGITS_LEFT_UNPADDED = { 7: torch.Tensor( [[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]], ).to(torch_device), - 8: torch.Tensor([[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]]).to( - torch_device + 8: torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to( + torch_device, ), 9: torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to( torch_device @@ -430,8 +418,8 @@ class MixtralIntegrationTest(unittest.TestCase): 7: torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to( torch_device ), - 8: torch.Tensor([[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]]).to( - torch_device + 8: torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to( + torch_device, ), 9: torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to( torch_device @@ -442,9 +430,6 @@ class MixtralIntegrationTest(unittest.TestCase): logits = model(dummy_input, attention_mask=attention_mask).logits logits = logits.float() - torch.testing.assert_close( - logits[0, :3, :3], EXPECTED_LOGITS_LEFT[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3 - ) torch.testing.assert_close( logits[0, -3:, -3:], EXPECTED_LOGITS_LEFT_UNPADDED[self.cuda_compute_capability_major_version], diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a85c9e7e625..ddef77eef13 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4461,6 +4461,7 @@ class ModelTesterMixin: del loss model = torch.compile(model, fullgraph=True, mode="reduce-overhead") + # forward compilation set_seed(42) loss = model(**inputs).loss
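
Note (editorial, not part of the patch): the masking_utils.py hunk keys the fallback check on `ALL_MASK_ATTENTION_FUNCTIONS._global_mapping` because, on Python < 3.11, dynamo refuses to inline `Mapping.__contains__` on the interface object and full-graph tracing aborts; the same limitation is why the `patch_mask_interface` context manager in integrations/executorch.py is no longer needed. Below is a minimal sketch of the membership pattern, assuming "sdpa" is among the registered mask implementations; `_global_mapping` is a private attribute, so this is illustrative only.

    from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS

    attn_implementation = "sdpa"  # assumed to be a registered mask implementation

    # Membership test against the plain dict, as `_preprocess_mask_arguments` does after
    # this patch. Testing `attn_implementation in ALL_MASK_ATTENTION_FUNCTIONS` instead
    # would hit the dynamo skip rule for Mapping.__contains__ on Python < 3.11.
    needs_no_mask = attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping
    print(needs_no_mask)  # expected: False for a registered implementation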
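
The common-test hunk exercises `torch.compile(..., fullgraph=True)` followed by a forward pass; the sketch below reproduces that path standalone. The checkpoint name is a placeholder, not taken from the patch (any small decoder-only model should behave the same), and a torch build with compile support is assumed.

    import torch
    from transformers import AutoModelForCausalLM

    # Placeholder checkpoint; any small causal LM works for this repro.
    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
    model.eval()

    # Same invocation as the common test: full-graph compilation should no longer
    # graph-break on the attention-mask interface lookup, even on Python < 3.11.
    compiled = torch.compile(model, fullgraph=True, mode="reduce-overhead")

    input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
    with torch.no_grad():
        logits = compiled(input_ids).logits  # first call triggers the forward compilation
    print(logits.shape)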