Fix some tests (especially compile with fullgraph=True on Python<3.11) (#38319)

* fix tests

* better fix for python<3.11

* fixes

* style
Cyril Vallez 2025-05-23 17:11:40 +02:00 committed by GitHub
parent a63bc17416
commit 896833c183
7 changed files with 48 additions and 83 deletions

View File

@@ -2232,7 +2232,7 @@ class OffloadedHybridCache(HybridChunkedCache):
     def _prefetch_layer_in_context(self, layer_idx: int) -> None:
         """Performs the actual copy of the layer to device cache."""
-        if len(self.key_cache) >= layer_idx:
+        if len(self.key_cache) > layer_idx:
             self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True)
             self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True)
         # The layer was not yet initialized
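The comparison above is the actual bug fix: a list-backed per-layer cache only has valid indices 0 through len(self.key_cache) - 1, so "this layer is already initialized" corresponds to len(self.key_cache) > layer_idx; the old >= check also accepted layer_idx == len(self.key_cache) and could index one past the end. A minimal standalone sketch of the invariant (plain Python, names are illustrative rather than the library's):

    # Hypothetical stand-in for the per-layer key cache: a plain list of tensors.
    key_cache = ["layer0_keys", "layer1_keys"]

    def layer_is_initialized(layer_idx: int) -> bool:
        # `>` is the correct containment test for list indices.
        return len(key_cache) > layer_idx

    assert layer_is_initialized(1) is True
    assert layer_is_initialized(2) is False  # `>=` would wrongly accept this and IndexError later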

View File

@@ -11,7 +11,6 @@
 # specific language governing permissions and limitations under the License.
 import logging
-from contextlib import contextmanager
 from typing import Callable, Optional

 import torch
@@ -110,14 +109,13 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
         example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
         example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

-        with patch_mask_interface():
-            exported_program = torch.export.export(
-                self.model,
-                args=(example_input_ids, example_cache_position),
-                kwargs={},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+        exported_program = torch.export.export(
+            self.model,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
         return exported_program

     @staticmethod
@@ -456,24 +454,6 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
         return outputs.logits


-@contextmanager
-def patch_mask_interface():
-    """
-    Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail
-    with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles:
-    Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`.
-    Note that this seem to be an issue only for python<3.11.
-    """
-    import transformers
-
-    original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
-    transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping
-    try:
-        yield
-    finally:
-        transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original
-
-
 def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
@@ -515,14 +495,13 @@ def convert_and_export_with_cache(
         )

     if is_torch_greater_or_equal("2.6.0"):
-        with patch_mask_interface():
-            exported_program = torch.export.export(
-                TorchExportableModuleWithStaticCache(model),
-                args=(example_input_ids, example_cache_position),
-                kwargs={},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+        exported_program = torch.export.export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
     else:
         if dynamic_shapes is not None:
             logging.warning(
@@ -534,14 +513,13 @@ def convert_and_export_with_cache(
         #
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
         # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
-        with patch_mask_interface():
-            exported_program = torch.export._trace._export(
-                TorchExportableModuleWithStaticCache(model),
-                args=(example_input_ids,),
-                kwargs={"cache_position": example_cache_position},
-                pre_dispatch=False,
-                strict=True,
-            )
+        exported_program = torch.export._trace._export(
+            TorchExportableModuleWithStaticCache(model),
+            args=(example_input_ids,),
+            kwargs={"cache_position": example_cache_position},
+            pre_dispatch=False,
+            strict=True,
+        )

     return exported_program
@@ -634,10 +612,9 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
         # Export the encoder
         with torch.no_grad():
-            with patch_mask_interface():
-                exported_encoder = torch.export.export(
-                    wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
-                )
+            exported_encoder = torch.export.export(
+                wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
+            )

         return exported_encoder
@@ -657,17 +634,16 @@ class Seq2SeqLMExportableModule(torch.nn.Module):
         # Export the decoder
         with torch.no_grad():
-            with patch_mask_interface():
-                exported_decoder = torch.export.export(
-                    wrapped_decoder,
-                    (decoder_input_ids, encoder_hidden_states, cache_position),
-                    dynamic_shapes={
-                        "decoder_input_ids": None,
-                        "encoder_hidden_states": {1: encoder_seq_len_dim},
-                        "cache_position": None,
-                    },
-                    strict=True,
-                )
+            exported_decoder = torch.export.export(
+                wrapped_decoder,
+                (decoder_input_ids, encoder_hidden_states, cache_position),
+                dynamic_shapes={
+                    "decoder_input_ids": None,
+                    "encoder_hidden_states": {1: encoder_seq_len_dim},
+                    "cache_position": None,
+                },
+                strict=True,
+            )

         return exported_decoder

View File

@@ -623,7 +623,11 @@ def _preprocess_mask_arguments(
         return True, attention_mask, None, None

     # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
-    if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS:
+    # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
+    # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
+    # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
+    # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
+    if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
         return True, None, None, None

     # Move the mask to correct device, and potentially switch dtype for efficiency
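For context on the `_global_mapping` workaround in this hunk: `ALL_MASK_ATTENTION_FUNCTIONS` is a `Mapping` subclass, so `key in ALL_MASK_ATTENTION_FUNCTIONS` dispatches to `Mapping.__contains__`, which dynamo refuses to inline on Python<3.11 (it lives in a module covered by the skip rules), whereas membership in the plain dict underneath traces fine under `fullgraph=True`. A rough self-contained sketch of that pattern, with an illustrative `Registry` class standing in for `AttentionMaskInterface` (not the library's code):

    from collections.abc import Mapping

    class Registry(Mapping):
        """Illustrative dict-backed registry; `in` on it goes through Mapping.__contains__."""

        def __init__(self, **entries):
            self._global_mapping = dict(entries)  # plain dict underneath

        def __getitem__(self, key):
            return self._global_mapping[key]

        def __iter__(self):
            return iter(self._global_mapping)

        def __len__(self):
            return len(self._global_mapping)

    REGISTRY = Registry(sdpa=object(), eager=object())

    # Under torch.compile(fullgraph=True) or torch.export on Python<3.11,
    # `"sdpa" in REGISTRY` hits Mapping.__contains__ from a dynamo-skipped module
    # and raises torch._dynamo.exc.Unsupported; checking the underlying dict
    # keeps the membership test inlineable:
    print("sdpa" in REGISTRY._global_mapping)  # dict.__contains__, dynamo-friendly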

View File

@@ -232,8 +232,8 @@ class CohereIntegrationTest(unittest.TestCase):
         EXPECTED_LOGITS = torch.Tensor(
             [
-                [[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]],
-                [[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]],
+                [[0.0000, 0.0285, 0.0322], [0.0000, 0.0011, 0.1105], [0.0000, -0.0018, -0.1019]],
+                [[0.0000, 0.1080, 0.0454], [0.0000, -0.1808, -0.1553], [0.0000, 0.0452, 0.0369]],
             ]
         ).to(device=torch_device, dtype=torch.float16)
@@ -251,4 +251,4 @@ class CohereIntegrationTest(unittest.TestCase):
         output = model(**inputs)
         logits = output.logits

-        torch.testing.assert_close(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(EXPECTED_LOGITS, logits[:, -3:, :3], rtol=1e-3, atol=1e-3)

View File

@@ -150,7 +150,6 @@ class CsmForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, u
     test_headmasking = False
     test_resize_embeddings = False
    test_resize_embeddings_untied = False
-    test_torch_exportable = True

     def setUp(self):
         self.model_tester = CsmModelTester(self)

View File

@@ -402,24 +402,12 @@ class MixtralIntegrationTest(unittest.TestCase):
         #
         # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
         # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_LOGITS_LEFT = {
-            7: torch.Tensor(
-                [[0.1904, 0.0500, 0.7187], [0.1933, 0.0515, 0.7187], [0.2001, 0.0559, 0.7148]],
-            ).to(torch_device),
-            8: torch.Tensor([[0.1914, 0.0508, 0.7188], [0.1953, 0.0510, 0.7227], [0.1973, 0.0562, 0.7148]]).to(
-                torch_device
-            ),
-            9: torch.Tensor([[0.1904, 0.0513, 0.7227], [0.1943, 0.0518, 0.7227], [0.1982, 0.0557, 0.7148]]).to(
-                torch_device
-            ),
-        }
-
         EXPECTED_LOGITS_LEFT_UNPADDED = {
             7: torch.Tensor(
                 [[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]],
             ).to(torch_device),
-            8: torch.Tensor([[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]]).to(
-                torch_device
+            8: torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to(
+                torch_device,
             ),
             9: torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to(
                 torch_device
@@ -430,8 +418,8 @@ class MixtralIntegrationTest(unittest.TestCase):
             7: torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to(
                 torch_device
             ),
-            8: torch.Tensor([[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]]).to(
-                torch_device
+            8: torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to(
+                torch_device,
             ),
             9: torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to(
                 torch_device
@@ -442,9 +430,6 @@ class MixtralIntegrationTest(unittest.TestCase):
             logits = model(dummy_input, attention_mask=attention_mask).logits
             logits = logits.float()

-        torch.testing.assert_close(
-            logits[0, :3, :3], EXPECTED_LOGITS_LEFT[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
-        )
         torch.testing.assert_close(
             logits[0, -3:, -3:],
             EXPECTED_LOGITS_LEFT_UNPADDED[self.cuda_compute_capability_major_version],

View File

@@ -4461,6 +4461,7 @@ class ModelTesterMixin:
         del loss

         model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
+        # forward compilation
         set_seed(42)
         loss = model(**inputs).loss
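The hunk above touches the compiled-training test that motivates the PR title: the model is recompiled with `fullgraph=True` and the forward pass is reseeded so any stochastic ops match the earlier eager run. A rough sketch of that eager-vs-compiled pattern (a hypothetical tiny randomly initialized model is used so the snippet is self-contained; this is not the exact test body):

    import torch
    from transformers import LlamaConfig, LlamaForCausalLM, set_seed

    # Tiny random model (illustrative, not a released checkpoint).
    config = LlamaConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=2,
                         num_attention_heads=4, num_key_value_heads=4, vocab_size=128)
    model = LlamaForCausalLM(config).train()
    inputs = {"input_ids": torch.randint(0, 128, (1, 8)), "labels": torch.randint(0, 128, (1, 8))}

    set_seed(42)
    eager_loss = model(**inputs).loss  # baseline eager forward

    model = torch.compile(model, fullgraph=True, mode="reduce-overhead")

    # forward compilation
    set_seed(42)  # reseed so any stochastic ops match the eager run
    compiled_loss = model(**inputs).loss

    print(eager_loss.item(), compiled_loss.item())  # expected to agree closely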