Mirror of https://github.com/huggingface/transformers.git
Enable traced model for text-generation task (#22265)
This commit is contained in:
parent 0558914dff
commit 8472a224fb
@@ -20,6 +20,7 @@
import argparse
import logging
from typing import Tuple

import numpy as np
import torch
@@ -27,6 +28,7 @@ import torch
from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GenerationMixin,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
@@ -38,6 +40,7 @@ from transformers import (
    XLNetLMHeadModel,
    XLNetTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast


logging.basicConfig(
@@ -151,6 +154,131 @@ def adjust_length_to_model(length, max_sequence_length):
    return length


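# Infer (num_layer, num_head, head_dim) from a model config, covering the different
# attribute names used across architectures (hidden_size / n_embed / n_embd for the
# embedding width, num_attention_heads / n_head for the number of heads).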
def sparse_model_config(model_config):
    embedding_size = None
    if hasattr(model_config, "hidden_size"):
        embedding_size = model_config.hidden_size
    elif hasattr(model_config, "n_embed"):
        embedding_size = model_config.n_embed
    elif hasattr(model_config, "n_embd"):
        embedding_size = model_config.n_embd

    num_head = None
    if hasattr(model_config, "num_attention_heads"):
        num_head = model_config.num_attention_heads
    elif hasattr(model_config, "n_head"):
        num_head = model_config.n_head

    if embedding_size is None or num_head is None or num_head == 0:
        raise ValueError("Check the model config")

    num_embedding_size_per_head = int(embedding_size / num_head)
    num_layer = model_config.n_layer

    return num_layer, num_head, num_embedding_size_per_head


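# Build example inputs for torch.jit.trace: tokenize a small dummy batch, create a
# zero-filled past_key_values tuple of sequence length 1, and extend the attention mask
# by one leading position to account for the dummy cache entry.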
def prepare_jit_inputs(inputs, model, tokenizer):
    num_batch = len(inputs)
    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
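    # BLOOM stores its per-layer cache as ([batch * heads, head_dim, seq_len],
    # [batch * heads, seq_len, head_dim]); the other decoder-only models use
    # [batch, heads, seq_len, head_dim] for both keys and values.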
    if model.config.model_type == "bloom":
        past_key_values = tuple(
            (
                torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1)
                .to(model.config.torch_dtype)
                .to(model.device),
                torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    else:
        past_key_values = tuple(
            (
                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )

    dummy_input["attention_mask"] = torch.cat(
        [
            torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype),
            dummy_input["attention_mask"],
        ],
        -1,
    )

    if model.config.use_cache:
        jit_inputs = (
            dummy_input["input_ids"].to(model.device),
            past_key_values,
            dummy_input["attention_mask"].to(model.device),
        )
    else:
        jit_inputs = (
            dummy_input["input_ids"].to(model.device),
            dummy_input["attention_mask"].to(model.device),
        )

    return jit_inputs


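# Wrapper that sends decoding steps to the traced graph once a KV cache exists and
# falls back to the original eager model for the first step, where past_key_values is
# still None. Input preparation, beam-search cache reordering, and any other attribute
# lookups are delegated to the wrapped eager model.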
class _ModelFallbackWrapper(GenerationMixin):
    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
        if kwargs["past_key_values"] is None:
            return self._default(*args, **kwargs)
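        # Subsequent steps: pass the remaining non-None, non-bool kwargs positionally to
        # the traced graph and repack its (logits, past_key_values) tuple as a
        # CausalLMOutputWithPast so generate() can keep consuming it.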
        trace_graph_inputs = []
        kwargs.pop("position_ids", None)
        for k, v in kwargs.items():
            if v is not None and not isinstance(v, bool):
                trace_graph_inputs.append(v)
        trace_graph_inputs = tuple(trace_graph_inputs)
        outputs = self._optimized(*trace_graph_inputs)
        lm_logits = outputs[0]
        past_key_values = outputs[1]
        fixed_output = CausalLMOutputWithPast(
            loss=None,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
        return fixed_output

    def __getattr__(self, item):
        return getattr(self._default, item)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
    ):
        return self._default.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
        )

    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return self._default._reorder_cache(past_key_values, beam_idx)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
@@ -196,6 +324,9 @@ def main():
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--jit", type=bool, default=False, help="Whether or not to use jit trace to accelerate inference"
    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -213,6 +344,8 @@ def main():
        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)

@@ -248,6 +381,18 @@ def main():
    else:
        input_ids = encoded_prompt

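    # With --jit: trace the model once on the dummy inputs, freeze the graph, run two
    # warm-up passes so TorchScript finishes optimizing it, then swap in the fallback
    # wrapper so generate() transparently uses the traced graph.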
    if args.jit:
        jit_input_texts = ["jit"]
        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
        torch._C._jit_set_texpr_fuser_enabled(False)
        model.config.return_dict = False
        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
        traced_model = torch.jit.freeze(traced_model.eval())
        traced_model(*jit_inputs)
        traced_model(*jit_inputs)

        model = _ModelFallbackWrapper(traced_model, model)

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
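A rough usage sketch: assuming this diff targets the example text-generation script (run_generation.py) and that its existing flags (--model_type, --model_name_or_path, --prompt, --length) are unchanged, tracing would be enabled along these lines, with the gpt2 checkpoint and prompt chosen purely for illustration:

python run_generation.py --model_type=gpt2 --model_name_or_path=gpt2 --prompt="Hello, my dog is" --length=20 --jit True

Because --jit is declared with type=bool rather than action="store_true", it expects an explicit value on the command line.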