mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
refactor: benchmarks (#33896)
* refactor: benchmarks Based on a discussion with @LysandreJik & @ArthurZucker, the goal of this PR is to improve transformers' benchmark system. This is a WIP, for the moment the infrastructure required to make things work is not ready. Will update the PR description when it is the case. * feat: add db init in benchmarks CI * fix: pg_config is missing in runner * fix: add psql to the runner * fix: connect info from env vars + PR comments * refactor: set database as env var * fix: invalid working directory * fix: `commit_msg` -> `commit_message` * fix: git marking checked out repo as unsafe * feat: add logging * fix: invalid device * feat: update grafana dashboard for prod grafana * feat: add `commit_id` to header table * feat: commit latest version of dashboard * feat: move measurements into json field * feat: remove drop table migration queries * fix: `torch.arrange` -> `torch.arange` * fix: add missing `s` to `cache_position` positional argument * fix: change model * revert: `cache_positions` -> `cache_position` * fix: set device for `StaticCache` * fix: set `StaticCache` dtype * feat: limit max cache len * fix script * raise error on failure! * not try catch * try to skip generate compilation * update * update docker image! * update * update again!@ * update * updates * ??? * ?? * use `torch.cuda.synchronize()` * fix json * nits * fix * fixed! * f**k * feat: add TTNT panels * feat: add try except --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
parent
80bee7b114
commit
144852fb6b
73
.github/workflows/benchmark.yml
vendored
73
.github/workflows/benchmark.yml
vendored
@ -1,43 +1,72 @@
|
||||
name: Self-hosted runner (benchmark)
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "17 2 * * *"
|
||||
workflow_call:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
types: [ opened, labeled, reopened, synchronize ]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
HF_HOME: /mnt/cache
|
||||
TF_FORCE_GPU_ALLOW_GROWTH: true
|
||||
|
||||
|
||||
jobs:
|
||||
benchmark:
|
||||
name: Benchmark
|
||||
runs-on:
|
||||
runs-on:
|
||||
group: aws-g5-4xlarge-cache
|
||||
container:
|
||||
image: huggingface/transformers-all-latest-gpu
|
||||
options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
image: huggingface/transformers-pytorch-gpu
|
||||
options: --gpus all --privileged --ipc host
|
||||
steps:
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
- name: Get repo
|
||||
if: github.event_name == 'pull_request'
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
- name: Get repo
|
||||
if: github.event_name == 'push'
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.sha }}
|
||||
|
||||
- name: Install libpq-dev & psql
|
||||
run: |
|
||||
git fetch && git checkout ${{ github.sha }}
|
||||
apt update
|
||||
apt install -y libpq-dev postgresql-client
|
||||
|
||||
- name: Install benchmark script dependencies
|
||||
run: python3 -m pip install -r benchmark/requirements.txt
|
||||
|
||||
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
|
||||
working-directory: /transformers
|
||||
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .
|
||||
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]"
|
||||
|
||||
- name: Benchmark (daily)
|
||||
if: github.event_name == 'schedule'
|
||||
working-directory: /transformers
|
||||
- name: Run database init script
|
||||
run: |
|
||||
python3 -m pip install optimum-benchmark>=0.3.0
|
||||
HF_TOKEN=${{ secrets.TRANSFORMERS_BENCHMARK_TOKEN }} python3 benchmark/benchmark.py --repo_id hf-internal-testing/benchmark_results --path_in_repo $(date +'%Y-%m-%d') --config-dir benchmark/config --config-name generation --commit=${{ github.sha }} backend.model=google/gemma-2b backend.cache_implementation=null,static backend.torch_compile=false,true --multirun
|
||||
psql -f benchmark/init_db.sql
|
||||
env:
|
||||
PGDATABASE: metrics
|
||||
PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }}
|
||||
PGUSER: transformers_benchmarks
|
||||
PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }}
|
||||
|
||||
- name: Benchmark (merged to main event)
|
||||
if: github.event_name == 'push' && github.ref_name == 'main'
|
||||
working-directory: /transformers
|
||||
- name: Run benchmark
|
||||
run: |
|
||||
python3 -m pip install optimum-benchmark>=0.3.0
|
||||
HF_TOKEN=${{ secrets.TRANSFORMERS_BENCHMARK_TOKEN }} python3 benchmark/benchmark.py --repo_id hf-internal-testing/benchmark_results_merge_event --path_in_repo $(date +'%Y-%m-%d') --config-dir benchmark/config --config-name generation --commit=${{ github.sha }} backend.model=google/gemma-2b backend.cache_implementation=null,static backend.torch_compile=false,true --multirun
|
||||
git config --global --add safe.directory /__w/transformers/transformers
|
||||
if [ "$GITHUB_EVENT_NAME" = "pull_request" ]; then
|
||||
commit_id=$(echo "${{ github.event.pull_request.head.sha }}")
|
||||
elif [ "$GITHUB_EVENT_NAME" = "push" ]; then
|
||||
commit_id=$GITHUB_SHA
|
||||
fi
|
||||
commit_msg=$(git show -s --format=%s | cut -c1-70)
|
||||
python3 benchmark/llama.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg"
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
|
||||
PGHOST: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGHOST }}
|
||||
PGUSER: transformers_benchmarks
|
||||
PGPASSWORD: ${{ secrets.TRANSFORMERS_BENCHMARKS_PGPASSWORD }}
|
||||
|
2211
benchmark/grafana_dashboard.json
Normal file
2211
benchmark/grafana_dashboard.json
Normal file
File diff suppressed because it is too large
Load Diff
26
benchmark/init_db.sql
Normal file
26
benchmark/init_db.sql
Normal file
@ -0,0 +1,26 @@
|
||||
CREATE TABLE IF NOT EXISTS benchmarks (
|
||||
benchmark_id SERIAL PRIMARY KEY,
|
||||
branch VARCHAR(255),
|
||||
commit_id VARCHAR(72),
|
||||
commit_message VARCHAR(70),
|
||||
gpu_name VARCHAR(255),
|
||||
created_at timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC')
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS device_measurements (
|
||||
measurement_id SERIAL PRIMARY KEY,
|
||||
benchmark_id int REFERENCES benchmarks (benchmark_id),
|
||||
cpu_util double precision,
|
||||
mem_megabytes double precision,
|
||||
gpu_util double precision,
|
||||
gpu_mem_megabytes double precision,
|
||||
time timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC')
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS model_measurements (
|
||||
measurement_id SERIAL PRIMARY KEY,
|
||||
benchmark_id int REFERENCES benchmarks (benchmark_id),
|
||||
measurements jsonb,
|
||||
time timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC')
|
||||
);
|
||||
|
404
benchmark/llama.py
Normal file
404
benchmark/llama.py
Normal file
@ -0,0 +1,404 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from statistics import mean
|
||||
from threading import Event, Thread
|
||||
from time import perf_counter, sleep
|
||||
from typing import Optional
|
||||
import gpustat
|
||||
import psutil
|
||||
import psycopg2
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache
|
||||
from psycopg2.extras import Json
|
||||
from psycopg2.extensions import register_adapter
|
||||
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter("[%(levelname)s - %(asctime)s] %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "1"
|
||||
torch.set_float32_matmul_precision("high")
|
||||
register_adapter(dict, Json)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""
|
||||
Parse command line arguments for the benchmarking CLI.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers.")
|
||||
|
||||
parser.add_argument(
|
||||
"branch",
|
||||
type=str,
|
||||
help="The branch name on which the benchmarking is performed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"commit_id",
|
||||
type=str,
|
||||
help="The commit hash on which the benchmarking is performed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"commit_msg",
|
||||
type=str,
|
||||
help="The commit message associated with the commit, truncated to 70 characters.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args.branch, args.commit_id, args.commit_msg
|
||||
|
||||
|
||||
def collect_metrics(benchmark_id, continue_metric_collection):
|
||||
p = psutil.Process(os.getpid())
|
||||
conn = psycopg2.connect("dbname=metrics")
|
||||
cur = conn.cursor()
|
||||
while not continue_metric_collection.is_set():
|
||||
with p.oneshot():
|
||||
cpu_util = p.cpu_percent()
|
||||
mem_megabytes = p.memory_info().rss / (1024 * 1024)
|
||||
gpu_stats = gpustat.GPUStatCollection.new_query()
|
||||
gpu_util = gpu_stats[0]["utilization.gpu"]
|
||||
gpu_mem_megabytes = gpu_stats[0]["memory.used"]
|
||||
cur.execute(
|
||||
"INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
|
||||
(benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
|
||||
)
|
||||
sleep(0.01)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def run_benchmark(branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
|
||||
continue_metric_collection = Event()
|
||||
metrics_thread = None
|
||||
try:
|
||||
gpu_stats = gpustat.GPUStatCollection.new_query()
|
||||
gpu_name = gpu_stats[0]["name"]
|
||||
conn = psycopg2.connect("dbname=metrics")
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"INSERT INTO benchmarks (branch, commit_id, commit_message, gpu_name) VALUES (%s, %s, %s, %s) RETURNING benchmark_id",
|
||||
(branch, commit_id, commit_msg, gpu_name),
|
||||
)
|
||||
conn.commit()
|
||||
benchmark_id = cur.fetchone()[0]
|
||||
metrics_thread = Thread(target=collect_metrics, args=[benchmark_id, continue_metric_collection])
|
||||
metrics_thread.start()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling
|
||||
|
||||
device = "cuda"
|
||||
ckpt = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
# This is to avoid counting download in model load time measurement
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
|
||||
gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
|
||||
start = perf_counter()
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
ckpt, torch_dtype=torch.float16, generation_config=gen_config
|
||||
).eval()
|
||||
model.to(device)
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
model_load_time = end - start
|
||||
logger.info(f"loaded model in: {model_load_time}s")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
|
||||
prompt = "Why dogs are so cute?"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
|
||||
# Specify the max length (including both the prompt and the response)
|
||||
# When calling `generate` with `cache_implementation="static" later, this is also used to create a `StaticCache` object
|
||||
# with sequence length = `max_length`. The longer the more you will re-use it
|
||||
seq_length = inputs["input_ids"].shape[1]
|
||||
model.generation_config.max_length = seq_length + num_tokens_to_generate
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
|
||||
# Copied from the gpt-fast repo
|
||||
def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
|
||||
q = torch.empty_like(probs_sort).exponential_(1)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
|
||||
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
pivot = v.select(-1, -1).unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
return probs
|
||||
|
||||
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
||||
probs = logits_to_probs(logits[:, -1], temperature, top_k)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
||||
def decode_one_token(model, cur_token, cache_position, past_key_values):
|
||||
logits = model(
|
||||
cur_token,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)[0]
|
||||
new_token = sample(logits, temperature=0.6, top_k=5)[0]
|
||||
return new_token
|
||||
|
||||
#########
|
||||
# Eager #
|
||||
#########
|
||||
with torch.no_grad():
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate,
|
||||
)
|
||||
cache_position = torch.arange(seq_length, device=device)
|
||||
start = perf_counter()
|
||||
model(
|
||||
**inputs,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)
|
||||
end = perf_counter()
|
||||
first_eager_fwd_pass_time = end - start
|
||||
logger.info(f"completed first eager fwd pass in: {first_eager_fwd_pass_time}s")
|
||||
start = perf_counter()
|
||||
output = model.generate(**inputs, do_sample=False)
|
||||
end = perf_counter()
|
||||
first_eager_generate_time = end - start
|
||||
logger.info(f"completed first eager generation in: {first_eager_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate,
|
||||
)
|
||||
cache_position = torch.arange(seq_length, device=device)
|
||||
start = perf_counter()
|
||||
model(
|
||||
**inputs,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=False,
|
||||
use_cache=True,
|
||||
)
|
||||
end = perf_counter()
|
||||
second_eager_fwd_pass_time = end - start
|
||||
logger.info(f"completed second eager fwd pass in: {second_eager_fwd_pass_time}s")
|
||||
start = perf_counter()
|
||||
model.generate(**inputs, do_sample=False)
|
||||
end = perf_counter()
|
||||
second_eager_generate_time = end - start
|
||||
logger.info(f"completed second eager generation in: {second_eager_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
torch.compiler.reset()
|
||||
|
||||
################
|
||||
# Forward pass #
|
||||
################
|
||||
|
||||
# `torch.compile(model, ...)` is not recommended as you compile callbacks
|
||||
# and full generate. We recommend compiling only the forward for now.
|
||||
# "reduce-overhead" will use cudagraphs.
|
||||
generated_ids = torch.zeros(
|
||||
(batch_size, num_tokens_to_generate + seq_length), dtype=torch.int, device=device
|
||||
)
|
||||
|
||||
generated_ids[:, :seq_length] = inputs["input_ids"]
|
||||
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
|
||||
# model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
# TODO use decode_one_token(model, input_id.clone(), cache_position) for verification
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + num_tokens_to_generate + 10,
|
||||
)
|
||||
cache_position = torch.arange(seq_length, device=device)
|
||||
all_generated_tokens = []
|
||||
### First compile, prefill
|
||||
start = perf_counter()
|
||||
next_token = decode_one_token(
|
||||
model, inputs["input_ids"], cache_position=cache_position, past_key_values=past_key_values
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
time_to_first_token = end - start
|
||||
logger.info(f"completed first compile generation in: {time_to_first_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
|
||||
cache_position = torch.tensor([seq_length], device=device)
|
||||
### First compile, decoding
|
||||
start = perf_counter()
|
||||
next_token = decode_one_token(
|
||||
model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
time_to_second_token = end - start
|
||||
logger.info(f"completed second compile generation in: {time_to_first_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
|
||||
### Second compile, decoding
|
||||
start = perf_counter()
|
||||
next_token = decode_one_token(
|
||||
model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
time_to_third_token = end - start
|
||||
logger.info(f"completed third compile forward in: {time_to_first_token}s")
|
||||
cache_position += 1
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
|
||||
### Using cuda graphs decoding
|
||||
|
||||
start = perf_counter()
|
||||
for _ in range(1, num_tokens_to_generate):
|
||||
all_generated_tokens += next_token.clone().detach().cpu().tolist()
|
||||
next_token = decode_one_token(
|
||||
model, next_token.clone(), cache_position=cache_position, past_key_values=past_key_values
|
||||
)
|
||||
cache_position += 1
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
mean_time_to_next_token = (end - start) / num_tokens_to_generate
|
||||
logger.info(f"completed next compile generation in: {mean_time_to_next_token}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(all_generated_tokens)}")
|
||||
|
||||
####################
|
||||
# Generate compile #
|
||||
####################
|
||||
torch.compiler.reset()
|
||||
# we will not compile full generate as it' s to intensive, tho we measure full forward!
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
)
|
||||
|
||||
# 1st call
|
||||
start = perf_counter()
|
||||
output = model.generate(**inputs, past_key_values=past_key_values)
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
first_compile_generate_time = end - start
|
||||
logger.info(f"completed first compile generation in: {first_compile_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
)
|
||||
# 2nd call
|
||||
start = perf_counter()
|
||||
output = model.generate(**inputs, past_key_values=past_key_values)
|
||||
torch.cuda.synchronize()
|
||||
end = perf_counter()
|
||||
second_compile_generate_time = end - start
|
||||
logger.info(f"completed second compile generation in: {second_compile_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
)
|
||||
|
||||
# 3nd call
|
||||
start = perf_counter()
|
||||
output = model.generate(**inputs, past_key_values=past_key_values)
|
||||
end = perf_counter()
|
||||
third_compile_generate_time = end - start
|
||||
logger.info(f"completed second compile generation in: {third_compile_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
past_key_values = StaticCache(
|
||||
model.config,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
max_cache_len=seq_length + 128,
|
||||
)
|
||||
# 4th call
|
||||
start = perf_counter()
|
||||
output = model.generate(**inputs, past_key_values=past_key_values)
|
||||
end = perf_counter()
|
||||
fourth_compile_generate_time = end - start
|
||||
logger.info(f"completed second compile generation in: {fourth_compile_generate_time}s")
|
||||
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO model_measurements (
|
||||
benchmark_id,
|
||||
measurements
|
||||
) VALUES (%s, %s)
|
||||
""",
|
||||
(
|
||||
benchmark_id,
|
||||
{
|
||||
"model_load_time": model_load_time,
|
||||
"first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
|
||||
"second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
|
||||
"first_eager_generate_time_secs": first_eager_generate_time,
|
||||
"second_eager_generate_time_secs": second_eager_generate_time,
|
||||
"time_to_first_token_secs": time_to_first_token,
|
||||
"time_to_second_token_secs": time_to_second_token,
|
||||
"time_to_third_token_secs": time_to_third_token,
|
||||
"time_to_next_token_mean_secs": mean_time_to_next_token,
|
||||
"first_compile_generate_time_secs": first_compile_generate_time,
|
||||
"second_compile_generate_time_secs": second_compile_generate_time,
|
||||
"third_compile_generate_time_secs": third_compile_generate_time,
|
||||
"fourth_compile_generate_time_secs": fourth_compile_generate_time,
|
||||
},
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Caught exception: {e}")
|
||||
continue_metric_collection.set()
|
||||
if metrics_thread is not None:
|
||||
metrics_thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
branch, commit_id, commit_msg = parse_arguments()
|
||||
run_benchmark(branch, commit_id, commit_msg, num_tokens_to_generate=20)
|
5
benchmark/requirements.txt
Normal file
5
benchmark/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
gpustat==1.1.1
|
||||
psutil==6.0.0
|
||||
psycopg2==2.9.9
|
||||
torch>=2.4.0
|
||||
hf_transfer
|
Loading…
Reference in New Issue
Block a user