Auto-compile when using a static cache (#34247)

* generate with compile

* nits

* simple

* generate with compile

* nits

* simple

* safe

* style

* Update src/transformers/generation/utils.py

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>

* remove TOKENIZER forked warning

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Arthur 2024-11-22 15:33:35 +01:00 committed by GitHub
parent d9e6f307e7
commit 597efd21d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,6 +15,7 @@
# limitations under the License.
import copy
import inspect
import os
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@ -3224,6 +3225,16 @@ class GenerationMixin:
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
# Plain-function wrapper that forwards all arguments to `model.forward`.
# It exists so the decoding loop has a single callable that can be swapped for
# its `torch.compile`d version (done just below when `past_key_values` is a
# `StaticCache` and the model is on a CUDA device).
def model_forward(model, *args, **kwargs):
return model.forward(*args, **kwargs)
if isinstance(model_kwargs.get("past_key_values"), StaticCache):
if self.device.type == "cuda":
logger.warning_once("Using `torch.compile`.")
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
i = 0
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
@ -3234,8 +3245,11 @@ class GenerationMixin:
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
# forward pass to get next token
outputs = self(**model_inputs, return_dict=True)
if i == 0:
outputs = self(**model_inputs, return_dict=True)
i += 1
else:
outputs = model_forward(self, return_dict=True, **model_inputs)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(