mirror of https://github.com/huggingface/transformers.git
Auto compile when static cache (#34247)
* generate with compile
* nits
* simple
* generate with compile
* nits
* simple
* safe
* style
* Update src/transformers/generation/utils.py
  Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
* remove TOKENIZER forked warning

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
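For context, a minimal usage sketch of the feature this commit adds, assuming a transformers release where generate() accepts cache_implementation="static". The checkpoint name is a placeholder (the model must support static caching), and per the diff the auto-compile path only triggers on CUDA:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)

# Requesting a static cache makes past_key_values a StaticCache instance,
# which is the condition the new code checks before calling torch.compile.
outputs = model.generate(**inputs, cache_implementation="static", max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))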
This commit is contained in:
parent d9e6f307e7
commit 597efd21d2
@@ -15,6 +15,7 @@
 # limitations under the License.
 import copy
 import inspect
+import os
 import warnings
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -3224,6 +3225,16 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

+        def model_forward(model, *args, **kwargs):
+            return model.forward(*args, **kwargs)
+
+        if isinstance(model_kwargs.get("past_key_values"), StaticCache):
+            if self.device.type == "cuda":
+                logger.warning_once("Using `torch.compile`.")
+                os.environ["TOKENIZERS_PARALLELISM"] = "0"
+                model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+
+        i = 0
         while self._has_unfinished_sequences(
             this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
         ):
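As an aside, a self-contained sketch of the compilation pattern this hunk sets up: a plain wrapper function compiled with CUDA-graph-friendly settings. (Setting TOKENIZERS_PARALLELISM to "0" in the hunk suppresses the tokenizers library's fork warning, matching "remove TOKENIZER forked warning" in the commit message.) All names below are illustrative, not part of the diff:

import torch

def forward_fn(x):
    # stand-in for model.forward; any capture-friendly function works
    return x * 2 + 1

# mode="reduce-overhead" targets per-call launch overhead (CUDA graphs on GPU);
# fullgraph=True fails loudly instead of silently falling back if the
# function cannot be traced into a single graph.
compiled_fn = torch.compile(forward_fn, mode="reduce-overhead", fullgraph=True)

x = torch.randn(4, 8)
assert torch.allclose(compiled_fn(x), forward_fn(x))  # same result; faster on repeated calls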
@@ -3234,8 +3245,11 @@
             model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
             model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

-            # forward pass to get next token
-            outputs = self(**model_inputs, return_dict=True)
+            if i == 0:
+                outputs = self(**model_inputs, return_dict=True)
+                i += 1
+            else:
+                outputs = model_forward(self, return_dict=True, **model_inputs)

             # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
             model_kwargs = self._update_model_kwargs_for_generation(
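Distilled from the two hunks above, a sketch of the warm-up dispatch: the first decoding step runs eagerly, and later steps go through the compiled wrapper. Plausibly, the first call is kept eager because the prefill step has a different shape than the later single-token decode steps. This illustrates the control flow only, with a toy module standing in for the model:

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

def model_forward(model, *args, **kwargs):
    return model.forward(*args, **kwargs)

model = TinyModel()
compiled_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)

x = torch.randn(2, 8)
for i in range(4):
    if i == 0:
        out = model(x)                    # first step stays eager (warm-up)
    else:
        out = compiled_forward(model, x)  # subsequent steps use the compiled path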