import argparse
import json
import logging
import os
import sys
from statistics import mean
from threading import Event, Thread
from time import perf_counter, sleep
from typing import Optional

import gpustat
import psutil
import psycopg2
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache
from psycopg2.extras import Json
from psycopg2.extensions import register_adapter
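
# Module-level setup: enable hf_transfer for faster model downloads, log to stdout, and
# register a psycopg2 adapter so that plain Python dicts can be passed as JSON query
# parameters (used later for the `measurements` column).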
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.INFO)
formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.set_float32_matmul_precision("high")
register_adapter(dict, Json)


def parse_arguments():
    """
    Parse command line arguments for the benchmarking CLI.
    """
    parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers repository.")

    parser.add_argument(
        "branch",
        type=str,
        help="The branch name on which the benchmarking is performed.",
    )

    parser.add_argument(
        "commit_id",
        type=str,
        help="The commit hash on which the benchmarking is performed.",
    )

    parser.add_argument(
        "commit_msg",
        type=str,
        help="The commit message associated with the commit, truncated to 70 characters.",
    )

    args = parser.parse_args()

    return args.branch, args.commit_id, args.commit_msg
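
# Device metrics are collected on a background thread so that CPU/GPU utilization is
# sampled continuously while the benchmark runs on the main thread: the thread polls
# psutil (process CPU and RSS) and gpustat (GPU utilization and memory) roughly every
# 10 ms and inserts one row per sample into `device_measurements` until the event is set.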
def collect_metrics(benchmark_id, continue_metric_collection):
    """Sample process and GPU metrics roughly every 10 ms until the stop event is set."""
    p = psutil.Process(os.getpid())
    conn = psycopg2.connect("dbname=metrics")
    cur = conn.cursor()
    while not continue_metric_collection.is_set():
        with p.oneshot():
            cpu_util = p.cpu_percent()
            mem_megabytes = p.memory_info().rss / (1024 * 1024)
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_util = gpu_stats[0]["utilization.gpu"]
        gpu_mem_megabytes = gpu_stats[0]["memory.used"]
        cur.execute(
            "INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
            (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
        )
        sleep(0.01)
    conn.commit()
    conn.close()
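
# The benchmark itself: a row is inserted into `benchmarks` up front (its RETURNING id ties
# the device and model measurements together), the metrics thread is started, and a series
# of eager and torch.compile latency measurements is taken and stored as a single JSON
# payload in `model_measurements`.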
def run_benchmark(branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
    continue_metric_collection = Event()
    metrics_thread = None
    try:
        gpu_stats = gpustat.GPUStatCollection.new_query()
        gpu_name = gpu_stats[0]["name"]
        conn = psycopg2.connect("dbname=metrics")
        cur = conn.cursor()
        cur.execute(
            "INSERT INTO benchmarks (branch, commit_id, commit_message, gpu_name) VALUES (%s, %s, %s, %s) RETURNING benchmark_id",
            (branch, commit_id, commit_msg, gpu_name),
        )
        conn.commit()
        benchmark_id = cur.fetchone()[0]
        metrics_thread = Thread(target=collect_metrics, args=[benchmark_id, continue_metric_collection])
        metrics_thread.start()

        os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence warnings when compiling

        device = "cuda"
        ckpt = "meta-llama/Llama-2-7b-hf"
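
        # Model load time is measured around the second `from_pretrained` call plus the move
        # to the GPU; `torch.cuda.synchronize()` ensures the weights are actually on the
        # device before the timer is stopped.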
        # This is to avoid counting the download in the model load time measurement
        model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
        gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
        start = perf_counter()
        model = AutoModelForCausalLM.from_pretrained(
            ckpt, torch_dtype=torch.float16, generation_config=gen_config
        ).eval()
        model.to(device)
        torch.cuda.synchronize()
        end = perf_counter()
        model_load_time = end - start
        logger.info(f"loaded model in: {model_load_time}s")

        tokenizer = AutoTokenizer.from_pretrained(ckpt)

        prompt = "Why dogs are so cute?"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Specify the max length (including both the prompt and the response).
        # When calling `generate` with `cache_implementation="static"` later, this is also used to create a
        # `StaticCache` object with sequence length = `max_length`. The longer it is, the longer the cache can be re-used.
        seq_length = inputs["input_ids"].shape[1]
        model.generation_config.max_length = seq_length + num_tokens_to_generate
        batch_size = inputs["input_ids"].shape[0]
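
        # The helpers below follow the gpt-fast approach: sampling uses the exponential-noise
        # (Gumbel-max style) trick, i.e. argmax(probs / q) with q ~ Exponential(1), which draws
        # from the categorical distribution without calling `torch.multinomial` and therefore
        # avoids a host-device synchronization inside the decoding loop.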
        # Copied from the gpt-fast repo
        def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
            q = torch.empty_like(probs_sort).exponential_(1)
            return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

        def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            logits = logits / max(temperature, 1e-5)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                pivot = v.select(-1, -1).unsqueeze(-1)
                logits = torch.where(logits < pivot, -float("Inf"), logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            return probs

        def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
            probs = logits_to_probs(logits[:, -1], temperature, top_k)
            idx_next = multinomial_sample_one_no_sync(probs)
            return idx_next, probs

        def decode_one_token(model, cur_token, cache_position, past_key_values):
            logits = model(
                cur_token,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )[0]
            new_token = sample(logits, temperature=0.6, top_k=5)[0]
            return new_token
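
        # Eager baseline: time the first and second forward pass over the full prompt with a
        # StaticCache, plus a first and second full `generate` call. The first measurements
        # include one-off warm-up costs; the second ones are closer to steady-state eager latency.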
        #########
        # Eager #
        #########
        with torch.no_grad():
            past_key_values = StaticCache(
                model.config,
                batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + num_tokens_to_generate,
            )
            cache_position = torch.arange(seq_length, device=device)
            start = perf_counter()
            model(
                **inputs,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )
            end = perf_counter()
            first_eager_fwd_pass_time = end - start
            logger.info(f"completed first eager fwd pass in: {first_eager_fwd_pass_time}s")
            start = perf_counter()
            output = model.generate(**inputs, do_sample=False)
            end = perf_counter()
            first_eager_generate_time = end - start
            logger.info(f"completed first eager generation in: {first_eager_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            past_key_values = StaticCache(
                model.config,
                batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + num_tokens_to_generate,
            )
            cache_position = torch.arange(seq_length, device=device)
            start = perf_counter()
            model(
                **inputs,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )
            end = perf_counter()
            second_eager_fwd_pass_time = end - start
            logger.info(f"completed second eager fwd pass in: {second_eager_fwd_pass_time}s")
            start = perf_counter()
            output = model.generate(**inputs, do_sample=False)
            end = perf_counter()
            second_eager_generate_time = end - start
            logger.info(f"completed second eager generation in: {second_eager_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            torch.compiler.reset()

            ################
            # Forward pass #
            ################

            # `torch.compile(model, ...)` is not recommended, as it would compile the callbacks
            # and the full `generate` call. We recommend compiling only the forward for now.
            # "reduce-overhead" will use cudagraphs.
            generated_ids = torch.zeros(
                (batch_size, num_tokens_to_generate + seq_length), dtype=torch.int, device=device
            )

            generated_ids[:, :seq_length] = inputs["input_ids"]
            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
            # model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
            # TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
            past_key_values = StaticCache(
                model.config,
                batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + num_tokens_to_generate + 10,
            )
            cache_position = torch.arange(seq_length, device=device)
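
            # Timing strategy for the compiled decode, following the section labels below:
            #   * the first `decode_one_token` call runs the prefill over the whole prompt and
            #     triggers compilation, so `time_to_first_token` includes compile time;
            #   * the second and third calls each decode a single token and pay the compile
            #     cost for the new single-token input shape;
            #   * the final loop measures steady-state per-token latency, where the compiled
            #     graph ("reduce-overhead", i.e. CUDA graphs) is simply replayed.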
            all_generated_tokens = []
            ### First compile, prefill
            start = perf_counter()
            next_token = decode_one_token(
                model, inputs["input_ids"], cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_first_token = end - start
            logger.info(f"completed first compile generation in: {time_to_first_token}s")
            cache_position += 1
            all_generated_tokens += next_token.clone().detach().cpu().tolist()

            cache_position = torch.tensor([seq_length], device=device)
            ### First compile, decoding
            start = perf_counter()
            next_token = decode_one_token(
                model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_second_token = end - start
            logger.info(f"completed second compile generation in: {time_to_second_token}s")
            cache_position += 1
            all_generated_tokens += next_token.clone().detach().cpu().tolist()

            ### Second compile, decoding
            start = perf_counter()
            next_token = decode_one_token(
                model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
            )
            torch.cuda.synchronize()
            end = perf_counter()
            time_to_third_token = end - start
            logger.info(f"completed third compile forward in: {time_to_third_token}s")
            cache_position += 1
            all_generated_tokens += next_token.clone().detach().cpu().tolist()

            ### Using cuda graphs decoding

            start = perf_counter()
            for _ in range(1, num_tokens_to_generate):
                all_generated_tokens += next_token.clone().detach().cpu().tolist()
                next_token = decode_one_token(
                    model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
                )
                cache_position += 1
            torch.cuda.synchronize()
            end = perf_counter()
            mean_time_to_next_token = (end - start) / num_tokens_to_generate
            logger.info(f"completed next compile generation in: {mean_time_to_next_token}s")
            logger.info(f"generated: {tokenizer.batch_decode(all_generated_tokens)}")
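
            # End-to-end `generate` latency, with a fresh StaticCache passed in for every call.
            # The compiler state is reset first, so the first call pays the warm-up cost again
            # and the later calls show how quickly repeated `generate` calls settle down.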
            ####################
            # Generate compile #
            ####################
            torch.compiler.reset()
            # We will not compile the full `generate`, as it is too intensive; we still measure the
            # end-to-end `generate` latency here.

            past_key_values = StaticCache(
                model.config,
                batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + 128,
            )

            # 1st call
            start = perf_counter()
            output = model.generate(**inputs, past_key_values=past_key_values)
            torch.cuda.synchronize()
            end = perf_counter()
            first_compile_generate_time = end - start
            logger.info(f"completed first compile generation in: {first_compile_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            past_key_values = StaticCache(
                model.config,
                batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + 128,
            )
            # 2nd call
            start = perf_counter()
            output = model.generate(**inputs, past_key_values=past_key_values)
            torch.cuda.synchronize()
            end = perf_counter()
            second_compile_generate_time = end - start
            logger.info(f"completed second compile generation in: {second_compile_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            past_key_values = StaticCache(
                model.config,
                batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + 128,
            )

            # 3rd call
            start = perf_counter()
            output = model.generate(**inputs, past_key_values=past_key_values)
            torch.cuda.synchronize()
            end = perf_counter()
            third_compile_generate_time = end - start
            logger.info(f"completed third compile generation in: {third_compile_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

            past_key_values = StaticCache(
                model.config,
                batch_size=batch_size,
                device=device,
                dtype=torch.float16,
                max_cache_len=seq_length + 128,
            )
            # 4th call
            start = perf_counter()
            output = model.generate(**inputs, past_key_values=past_key_values)
            torch.cuda.synchronize()
            end = perf_counter()
            fourth_compile_generate_time = end - start
            logger.info(f"completed fourth compile generation in: {fourth_compile_generate_time}s")
            logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
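
        # All model-level timings are stored as a single JSON document in the `measurements`
        # column; the dict below is converted to JSON automatically thanks to the
        # `register_adapter(dict, Json)` call at module level.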
        cur.execute(
            """
            INSERT INTO model_measurements (
                benchmark_id,
                measurements
            ) VALUES (%s, %s)
            """,
            (
                benchmark_id,
                {
                    "model_load_time": model_load_time,
                    "first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
                    "second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
                    "first_eager_generate_time_secs": first_eager_generate_time,
                    "second_eager_generate_time_secs": second_eager_generate_time,
                    "time_to_first_token_secs": time_to_first_token,
                    "time_to_second_token_secs": time_to_second_token,
                    "time_to_third_token_secs": time_to_third_token,
                    "time_to_next_token_mean_secs": mean_time_to_next_token,
                    "first_compile_generate_time_secs": first_compile_generate_time,
                    "second_compile_generate_time_secs": second_compile_generate_time,
                    "third_compile_generate_time_secs": third_compile_generate_time,
                    "fourth_compile_generate_time_secs": fourth_compile_generate_time,
                },
            ),
        )
        conn.commit()
        conn.close()
    except Exception as e:
        logger.error(f"Caught exception: {e}")
    continue_metric_collection.set()
    if metrics_thread is not None:
        metrics_thread.join()
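
# Example invocation (the script name is illustrative; the three positional arguments are the
# ones defined in `parse_arguments`):
#   python benchmark.py <branch> <commit_id> <commit_msg>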
if __name__ == "__main__":
    branch, commit_id, commit_msg = parse_arguments()
    run_benchmark(branch, commit_id, commit_msg, num_tokens_to_generate=20)