Enable traced model for text-generation task (#22265)

jiqing-feng 2023-03-22 18:19:26 +08:00 committed by GitHub
parent 0558914dff
commit 8472a224fb

@@ -20,6 +20,7 @@
import argparse
import logging
from typing import Tuple
import numpy as np
import torch
@@ -27,6 +28,7 @@ import torch
from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GenerationMixin,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
@@ -38,6 +40,7 @@ from transformers import (
    XLNetLMHeadModel,
    XLNetTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

logging.basicConfig(
@@ -151,6 +154,131 @@ def adjust_length_to_model(length, max_sequence_length):
    return length

def sparse_model_config(model_config):
    embedding_size = None
    if hasattr(model_config, "hidden_size"):
        embedding_size = model_config.hidden_size
    elif hasattr(model_config, "n_embed"):
        embedding_size = model_config.n_embed
    elif hasattr(model_config, "n_embd"):
        embedding_size = model_config.n_embd

    num_head = None
    if hasattr(model_config, "num_attention_heads"):
        num_head = model_config.num_attention_heads
    elif hasattr(model_config, "n_head"):
        num_head = model_config.n_head

    if embedding_size is None or num_head is None or num_head == 0:
        raise ValueError("Check the model config")

    num_embedding_size_per_head = int(embedding_size / num_head)
    num_layer = model_config.n_layer

    return num_layer, num_head, num_embedding_size_per_head
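
For illustration, here is what `sparse_model_config` yields for a stock GPT-2 configuration (a usage sketch, not part of the committed script; it assumes `transformers` is installed and the function above is in scope):

from transformers import AutoConfig

# GPT-2 base: 12 layers, 12 heads, hidden size 768 -> 64 dimensions per head.
config = AutoConfig.from_pretrained("gpt2")
num_layer, num_head, head_dim = sparse_model_config(config)
print(num_layer, num_head, head_dim)  # expected: 12 12 64
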
def prepare_jit_inputs(inputs, model, tokenizer):
    num_batch = len(inputs)
    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
    # Build a zero-filled past_key_values of sequence length 1 so the traced graph sees
    # the cached-decoding signature. Bloom lays the cache out as
    # (batch * heads, head_dim, seq) / (batch * heads, seq, head_dim); other models use
    # (batch, heads, seq, head_dim) for both key and value.
    if model.config.model_type == "bloom":
        past_key_values = tuple(
            (
                torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1)
                .to(model.config.torch_dtype)
                .to(model.device),
                torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    else:
        past_key_values = tuple(
            (
                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )

    # Extend the attention mask by one column to account for the dummy past position.
    dummy_input["attention_mask"] = torch.cat(
        [
            torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype),
            dummy_input["attention_mask"],
        ],
        -1,
    )

    if model.config.use_cache:
        jit_inputs = (
            dummy_input["input_ids"].to(model.device),
            past_key_values,
            dummy_input["attention_mask"].to(model.device),
        )
    else:
        jit_inputs = (
            dummy_input["input_ids"].to(model.device),
            dummy_input["attention_mask"].to(model.device),
        )
    return jit_inputs
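
As a rough check of what `prepare_jit_inputs` produces, the sketch below inspects the dummy inputs for a GPT-2 checkpoint (illustrative only; `torch_dtype` is passed explicitly because the function above reads `model.config.torch_dtype`, and the exact sequence length depends on how the tokenizer splits the prompt):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float32)

input_ids, past_key_values, attention_mask = prepare_jit_inputs(["jit"], model, tokenizer)
print(len(past_key_values))         # 12 -- one (key, value) pair per layer
print(past_key_values[0][0].shape)  # torch.Size([1, 12, 1, 64]) -- (batch, heads, past_len, head_dim)
print(attention_mask.shape[-1] - input_ids.shape[-1])  # 1 -- the extra column for the dummy past position
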
class _ModelFallbackWrapper(GenerationMixin):
    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
        if kwargs["past_key_values"] is None:
            return self._default(*args, **kwargs)
        trace_graph_inputs = []
        kwargs.pop("position_ids", None)
        for k, v in kwargs.items():
            if v is not None and not isinstance(v, bool):
                trace_graph_inputs.append(v)
        trace_graph_inputs = tuple(trace_graph_inputs)
        outputs = self._optimized(*trace_graph_inputs)
        lm_logits = outputs[0]
        past_key_values = outputs[1]
        fixed_output = CausalLMOutputWithPast(
            loss=None,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
        return fixed_output

    def __getattr__(self, item):
        return getattr(self._default, item)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
    ):
        return self._default.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
        )

    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return self._default._reorder_cache(past_key_values, beam_idx)
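
A minimal end-to-end sketch of how the wrapper is meant to be used (hypothetical, mirroring the `--jit` branch of `main()` below): trace and freeze the model, wrap it, then call `generate()` on the wrapper as if it were the model. The first decoding step falls back to the eager model because there is no cache yet; later steps run the traced graph and repackage its tuple output as `CausalLMOutputWithPast`.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float32)

jit_inputs = prepare_jit_inputs(["jit"], model, tokenizer)
model.config.return_dict = False
traced = torch.jit.freeze(torch.jit.trace(model, jit_inputs, strict=False).eval())
traced(*jit_inputs)  # warm-up runs so the JIT can apply its optimizations
traced(*jit_inputs)

wrapped = _ModelFallbackWrapper(traced, model)
prompt = tokenizer("Hello, my dog", return_tensors="pt")
output_ids = wrapped.generate(input_ids=prompt["input_ids"], max_length=20, do_sample=False)
print(tokenizer.decode(output_ids[0]))
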
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
@@ -196,6 +324,9 @@ def main():
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference"
    )
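    # Example invocation with the new flag (illustrative; any checkpoint from
    # MODEL_CLASSES works the same way):
    #   python run_generation.py --model_type=gpt2 --model_name_or_path=gpt2 \
    #       --prompt "Hello, I'm a language model," --length 50 --jit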
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -213,6 +344,8 @@ def main():
        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)
@@ -248,6 +381,18 @@ def main():
    else:
        input_ids = encoded_prompt

    if args.jit:
        jit_input_texts = ["jit"]
        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
        # Disable the TensorExpr fuser, trace against tuple outputs (return_dict=False),
        # then freeze the graph and run two warm-up passes so the JIT finishes its
        # profiling and fusion before generation is run.
        torch._C._jit_set_texpr_fuser_enabled(False)
        model.config.return_dict = False
        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
        traced_model = torch.jit.freeze(traced_model.eval())
        traced_model(*jit_inputs)
        traced_model(*jit_inputs)

        model = _ModelFallbackWrapper(traced_model, model)

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),