Test: generate with torch.compile(model.forward) as a fast test (#34544)

This commit is contained in:
parent f48ecd7608
commit ece8c42488
@@ -349,7 +349,7 @@ In case you are using Sink Cache, you have to crop your inputs to that maximum l
 >>> user_prompts = ["Hello, what's your name?", "Btw, yesterday I was on a rock concert."]

 >>> past_key_values = DynamicCache()
->>> max_cache_length = past_key_values.get_max_length()
+>>> max_cache_length = past_key_values.get_max_cache_shape()

 >>> messages = []
 >>> for prompt in user_prompts:
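For context, a minimal sketch of the renamed accessor (illustrative, not part of the commit; exact `StaticCache` constructor kwargs vary slightly across releases):

    # Sketch: fixed-size caches report their capacity, dynamic caches report None.
    from transformers import GPT2Config
    from transformers.cache_utils import DynamicCache, StaticCache

    config = GPT2Config()  # any decoder-only config with the standard attention attributes

    dynamic_cache = DynamicCache()
    static_cache = StaticCache(config=config, max_batch_size=1, max_cache_len=128)

    print(dynamic_cache.get_max_cache_shape())  # expected: None -- a DynamicCache grows with the sequence
    print(static_cache.get_max_cache_shape())   # expected: 128 -- a StaticCache is pre-allocated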
@@ -29,6 +29,8 @@ class Cache(torch.nn.Module):
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """

+    is_compileable = False
+
     def __init__(self):
         super().__init__()

@@ -1098,6 +1100,8 @@ class StaticCache(Cache):
     ```
     """

+    is_compileable = True
+
     # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
     @deprecate_kwarg("layer_device_map", version="4.52.0")
     def __init__(

@@ -1297,6 +1301,7 @@ class SlidingWindowCache(StaticCache):
     """

     is_sliding = True
+    is_compileable = True

     # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
     def __init__(

@@ -1421,6 +1426,7 @@ class EncoderDecoderCache(Cache):
         super().__init__()
         self.self_attention_cache = self_attention_cache
         self.cross_attention_cache = cross_attention_cache
+        self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False)

         self.is_updated = {}
         for layer_idx in range(len(cross_attention_cache.key_cache)):

@@ -1612,6 +1618,8 @@ class HybridCache(Cache):
     ```
     """

+    is_compileable = True
+
     # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
     @deprecate_kwarg("layer_device_map", version="4.52.0")
     def __init__(

@@ -1832,6 +1840,8 @@ class MambaCache:
     ```
     """

+    is_compileable = True
+
     # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
     def __init__(
         self,

@@ -1975,6 +1985,8 @@ class OffloadedStaticCache(StaticCache):
     ```
     """

+    is_compileable = True
+
     @deprecate_kwarg("layer_device_map", version="4.52.0")
     def __init__(
         self,
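The new `is_compileable` class attribute lets callers ask a cache instance whether wrapping the model forward in `torch.compile` is expected to work, instead of hard-coding an `isinstance(..., StaticCache)` check. A hedged sketch of that duck-typed dispatch (illustrative only; the real logic lives in `GenerationMixin`, shown in a later hunk):

    import torch

    def maybe_compile_forward(model, past_key_values):
        # Unknown or legacy cache objects default to "not compileable", matching the base-class default above.
        if getattr(past_key_values, "is_compileable", False):
            return torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        return model.forward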
@@ -1579,7 +1579,7 @@ class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):


 @dataclass
-class CompileConfig(object):
+class CompileConfig:
     """
     Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`.
     See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.

@@ -1620,7 +1620,9 @@ class CompileConfig(object):
     backend: Union[str, Callable] = "inductor"
     mode: str = "reduce-overhead"
     options: Optional[dict] = None
+    # Used to flag our `generate` call to compile on e.g. CPU. Often not optimal, but useful for testing purposes.
+    _compile_all_devices = None

     def to_dict(self) -> Dict[str, Any]:
         """Serializes this instance to a Python dictionary."""
-        return copy.deepcopy(self.__dict__)
+        return copy.deepcopy({key: value for key, value in self.__dict__.items() if key != "_compile_all_devices"})
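A small usage sketch for the updated `CompileConfig` (the import path and the `generate(..., compile_config=...)` hook are assumptions based on this file; `backend` and `mode` are the fields shown in the hunk above):

    from transformers.generation.configuration_utils import CompileConfig  # import path assumed

    # Configure how `generate` will invoke torch.compile once a compileable cache is in use.
    cfg = CompileConfig(backend="inductor", mode="reduce-overhead")

    # The private `_compile_all_devices` testing flag is stored on the instance but, per the
    # `to_dict` change above, deliberately excluded from serialization.
    cfg._compile_all_devices = True
    assert "_compile_all_devices" not in cfg.to_dict()

    # Assumed usage: model.generate(**inputs, cache_implementation="static", compile_config=cfg)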
@@ -3177,9 +3177,11 @@ class GenerationMixin:
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         model_forward = self.__call__
-        if isinstance(model_kwargs.get("past_key_values"), StaticCache):
-            if self.device.type == "cuda":
-                logger.warning_once("Using `torch.compile`.")
+        if isinstance(model_kwargs.get("past_key_values"), Cache):
+            is_compileable = model_kwargs["past_key_values"].is_compileable
+            if is_compileable and (
+                self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
+            ):
                 os.environ["TOKENIZERS_PARALLELISM"] = "0"
                 model_forward = self.get_compiled_call(generation_config.compile_config)

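End-to-end, the dispatch above means `generate` auto-compiles the forward pass whenever the cache reports itself as compileable and the model sits on CUDA, while the private `_compile_all_devices` flag forces the same path on CPU for fast CI. A hedged sketch of exercising it manually (the tiny checkpoint name is an assumption; `_compiled_call` is the attribute the new test inspects):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumption: any small compile-friendly causal LM
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Force the compiled-forward path even on CPU, mirroring what the rewritten test does.
    model.generation_config.compile_config._compile_all_devices = True

    inputs = tokenizer("Hello", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=5, cache_implementation="static")

    print(hasattr(model, "_compiled_call"))  # expected: True -- the lazily compiled forward was created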
@@ -708,7 +708,7 @@ class AriaPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _supports_static_cache = False  # MoE models don't work with torch.compile (dynamic slicing)
     _supports_attention_backend = False

     def _init_weights(self, module):

@@ -1561,6 +1561,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            logits_to_keep=logits_to_keep,
             cache_position=cache_position,
         )

         logits = outputs[0]
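The same one-line pattern repeats for every affected architecture below: the capability is declared on the pretrained-model class so that both `generate` and the new test can discover it. Hypothetical illustration (this class and its comment are not from the commit):

    from transformers import PreTrainedModel

    class MyMoEPreTrainedModel(PreTrainedModel):  # hypothetical model, for illustration only
        _supports_cache_class = True
        _supports_quantized_cache = True
        _supports_static_cache = False  # e.g. data-dependent MoE routing that torch.compile cannot trace with fullgraph=True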
@@ -1223,6 +1223,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):


 class AriaPreTrainedModel(LlamaPreTrainedModel):
+    _supports_static_cache = False  # MoE models don't work with torch.compile (dynamic slicing)
     _supports_attention_backend = False

     def _init_weights(self, module):

@@ -1535,6 +1536,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            logits_to_keep=logits_to_keep,
             cache_position=cache_position,
         )

         logits = outputs[0]
@@ -833,6 +833,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
+    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

     def _init_weights(self, module: nn.Module):
         std = self.config.initializer_range
@@ -1802,6 +1802,7 @@ EMU3_INPUTS_DOCSTRING = r"""

 class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["text_model.lm_head.weight"]
+    _supports_static_cache = False  # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable

     def __init__(self, config):
         super().__init__(config)

@@ -1113,6 +1113,7 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):

 class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["text_model.lm_head.weight"]
+    _supports_static_cache = False  # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable

     def __init__(self, config):
         super().__init__(config)
@@ -52,7 +52,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _supports_static_cache = False  # TODO (fix me): compilation fails due to a stride error?

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -843,6 +843,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
+    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -917,6 +917,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = False  # IDEFICS cannot compile due to dynamic control flow when checking inputs

     def _init_weights(self, module):
         # important: this ported version of Idefics isn't meant for training from scratch - only
@@ -485,7 +485,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _supports_attention_backend = True

     def _init_weights(self, module):
@@ -45,7 +45,9 @@ from ..mistral.modeling_mistral import (
     MistralForSequenceClassification,
     MistralForTokenClassification,
     MistralModel,
+    MistralPreTrainedModel,
     MistralRMSNorm,
+    MistralRotaryEmbedding,
 )
 from .configuration_mixtral import MixtralConfig

@@ -313,6 +315,14 @@ class MixtralDecoderLayer(nn.Module):
         return outputs


+class MixtralRotaryEmbedding(MistralRotaryEmbedding):
+    pass
+
+
+class MixtralPreTrainedModel(MistralPreTrainedModel):
+    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+
+
 class MixtralModel(MistralModel):
     def __init__(self, config: MixtralConfig):
         super().__init__(config)
@@ -767,7 +767,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -912,7 +912,7 @@ class PhimoePreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
-    _supports_static_cache = True
+    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -332,7 +332,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_static_cache = True
+    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -882,7 +882,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_static_cache = True
+    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -1978,52 +1978,82 @@ class GenerationTesterMixin:
             model.generate(**generation_kwargs, **inputs_dict)

     @pytest.mark.generate
-    @require_torch_accelerator
-    @slow
     def test_generate_compile_model_forward(self):
         """
-        Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
-        end-to-end compilation and forward pass compilation only.
+        Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
         ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
         """
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_static_cache:
-                self.skipTest("This model doesn't support static cache")
+                self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")

-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4)

             model = model_class(config).to(torch_device)
             model.eval()  # otherwise `self.training` is `True` -- this flag is used at attn mask creation time

-            input_ids = inputs_dict["input_ids"].to(torch_device)
+            main_input = inputs_dict[model.main_input_name].to(torch_device)
             # creates two sets of *different* inputs with the same shape
-            half_batch_size = input_ids.shape[0] // 2
-            input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]]
-            self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape)
+            half_batch_size = main_input.shape[0] // 2
+            input_1 = {}
+            input_2 = {}
+            for key, value in inputs_dict.items():
+                if isinstance(value, torch.Tensor):
+                    input_1[key] = value[:half_batch_size, :].to(torch_device)
+                    input_2[key] = value[half_batch_size : half_batch_size * 2, :].to(torch_device)
+                else:
+                    input_1[key] = value
+                    input_2[key] = value
+            model_input_sets = [input_1, input_2]
+            self.assertTrue(
+                model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape
+            )

+            # compilation-specific setup
+            torch.compiler.reset()  # prevent cached compilation from being used in the test
+            has_defined_cache_implementation = model.generation_config.cache_implementation is not None
+            model.generation_config.compile_config._compile_all_devices = True  # force compilation (e.g. fast CI, CPU)
+
             generation_kwargs = {
                 "do_sample": False,
-                "max_new_tokens": 10,
+                "max_new_tokens": 5,
                 "return_dict_in_generate": True,
                 "output_scores": True,
-                "cache_implementation": "static",
             }

             # get eager + dynamic cache results for future comparison
             dynamic_outputs = []
-            for model_inputs in input_ids_sets:
-                dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs))
+            for model_inputs in model_input_sets:
+                gen_out = model.generate(**model_inputs, **generation_kwargs)
+                dynamic_outputs.append(gen_out)
+                # sanity checks for the default cache implementation
+                if not has_defined_cache_implementation:
+                    decoder_cache = (
+                        gen_out.past_key_values.self_attention_cache
+                        if config.is_encoder_decoder
+                        else gen_out.past_key_values
+                    )
+                    self.assertTrue(isinstance(decoder_cache, DynamicCache))
+                    self.assertFalse(decoder_cache.is_compileable)
+                    self.assertFalse(hasattr(model, "_compiled_call"))  # our auto compile should NOT have been called

-            # get compiled results
-            generation_config = copy.deepcopy(model.generation_config)
-            generation_config.update(**generation_kwargs)
-            torch.compiler.reset()
-
-            model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
+            # get compiled results -- relies on the automatic compilation triggered by specific "cache_implementation"
+            if not has_defined_cache_implementation:
+                generation_kwargs["cache_implementation"] = "static"

             compiled_outputs = []
-            for model_inputs in input_ids_sets:
-                compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config))
+            for model_inputs in model_input_sets:
+                gen_out = model.generate(**model_inputs, **generation_kwargs)
+                compiled_outputs.append(gen_out)
+                # sanity checks
+                decoder_cache = (
+                    gen_out.past_key_values.self_attention_cache
+                    if config.is_encoder_decoder
+                    else gen_out.past_key_values
+                )
+                self.assertFalse(isinstance(decoder_cache, DynamicCache))
+                self.assertTrue(decoder_cache.is_compileable)
+                self.assertTrue(hasattr(model, "_compiled_call"))  # our auto compile should have been called

             for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
                 self._check_similar_generate_outputs(dynamic_result, compiled_result)
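For reference, the final comparison loop relies on the mixin helper `_check_similar_generate_outputs`; in spirit it checks that the compiled/static run reproduces the eager/dynamic run. A rough standalone sketch (illustrative only; the real helper is more tolerant, e.g. it allows token differences caused by near-tied scores):

    import torch

    def check_similar_generate_outputs(dynamic_out, compiled_out, atol=1e-3, rtol=1e-3):
        # Same generated token ids (available because `return_dict_in_generate=True`)...
        assert torch.equal(dynamic_out.sequences, compiled_out.sequences)
        # ...and per-step scores that differ only by numerical noise from compilation (`output_scores=True`).
        for eager_scores, compiled_scores in zip(dynamic_out.scores, compiled_out.scores):
            torch.testing.assert_close(eager_scores, compiled_scores, atol=atol, rtol=rtol)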
@@ -331,11 +331,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
     def test_batching_equivalence(self):
         pass

-    # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
-    @unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
-    def test_generate_compile_model_forward(self):
-        pass
-

 @require_torch
 class ChameleonIntegrationTest(unittest.TestCase):
@@ -368,10 +368,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_disk_offload_bin(self):
         pass

-    @unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
-    def test_generate_compile_model_forward(self):
-        pass
-

 @require_torch
 class DbrxModelIntegrationTest(unittest.TestCase):
@@ -780,10 +780,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
     def test_custom_4d_attention_mask(self):
         pass

-    @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
-    def test_generate_compile_model_forward(self):
-        pass
-
     @unittest.skip(reason="We only test the model that takes in multiple images")
     def test_model(self):
         pass
@@ -332,10 +332,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass

-    @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`")
-    def test_generate_compile_model_forward(self):
-        pass
-

 @require_torch
 class Qwen2VLIntegrationTest(unittest.TestCase):
@@ -1602,6 +1602,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         with self.assertRaises(ValueError):
             model(input_features=input_features, labels=labels)

+    # TODO (joao, eustache): fix me :)
+    @unittest.skip(reason="Whisper's custom generate is not consistent regarding the cache return types")
+    def test_generate_compile_model_forward(self):
+        pass
+

 @require_torch
 @require_torchaudio
@@ -364,7 +364,7 @@ class CacheIntegrationTest(unittest.TestCase):
             input_ids = gen_out

         # We went well beyond the cache length
-        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+        self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)

         # And it still produces a coherent english
         decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)