mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
feat: add `repository` field to benchmarks table (#38582)
* feat: add `repository` field to benchmarks table
* fix: remove unwanted `,`
This commit is contained in:
parent
1285aec4cc
commit
ae3733f06e
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
@ -64,7 +64,7 @@ jobs:
|
|||||||
commit_id=$GITHUB_SHA
|
commit_id=$GITHUB_SHA
|
||||||
fi
|
fi
|
||||||
commit_msg=$(git show -s --format=%s | cut -c1-70)
|
commit_msg=$(git show -s --format=%s | cut -c1-70)
|
||||||
python3 benchmark/benchmarks_entrypoint.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
|
python3 benchmark/benchmarks_entrypoint.py "huggingface/transformers" "$BRANCH_NAME" "$commit_id" "$commit_msg"
|
||||||
env:
|
env:
|
||||||
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
|
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
|
||||||
# Enable this to see debug logs
|
# Enable this to see debug logs
|
||||||
|
@ -2,11 +2,11 @@ import argparse
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
from psycopg2.extras import Json
|
|
||||||
from psycopg2.extensions import register_adapter
|
from psycopg2.extensions import register_adapter
|
||||||
|
from psycopg2.extras import Json
|
||||||
|
|
||||||
|
|
||||||
register_adapter(dict, Json)
|
register_adapter(dict, Json)
|
||||||
@ -17,10 +17,13 @@ class ImportModuleException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class MetricsRecorder:
|
class MetricsRecorder:
|
||||||
def __init__(self, connection, logger: logging.Logger, branch: str, commit_id: str, commit_msg: str):
|
def __init__(
|
||||||
|
self, connection, logger: logging.Logger, repository: str, branch: str, commit_id: str, commit_msg: str
|
||||||
|
):
|
||||||
self.conn = connection
|
self.conn = connection
|
||||||
self.conn.autocommit = True
|
self.conn.autocommit = True
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
self.repository = repository
|
||||||
self.branch = branch
|
self.branch = branch
|
||||||
self.commit_id = commit_id
|
self.commit_id = commit_id
|
||||||
self.commit_msg = commit_msg
|
self.commit_msg = commit_msg
|
||||||
@ -32,8 +35,8 @@ class MetricsRecorder:
|
|||||||
# gpu_name: str, model_id: str
|
# gpu_name: str, model_id: str
|
||||||
with self.conn.cursor() as cur:
|
with self.conn.cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"INSERT INTO benchmarks (branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s) RETURNING benchmark_id",
|
"INSERT INTO benchmarks (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
|
||||||
(self.branch, self.commit_id, self.commit_msg, metadata),
|
(self.repository, self.branch, self.commit_id, self.commit_msg, metadata),
|
||||||
)
|
)
|
||||||
benchmark_id = cur.fetchone()[0]
|
benchmark_id = cur.fetchone()[0]
|
||||||
logger.debug(f"initialised benchmark #{benchmark_id}")
|
logger.debug(f"initialised benchmark #{benchmark_id}")
|
||||||
@ -82,12 +85,18 @@ handler.setFormatter(formatter)
|
|||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments() -> Tuple[str, str, str, str]:
|
||||||
"""
|
"""
|
||||||
Parse command line arguments for the benchmarking CLI.
|
Parse command line arguments for the benchmarking CLI.
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers.")
|
parser = argparse.ArgumentParser(description="CLI for benchmarking the huggingface/transformers.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"repository",
|
||||||
|
type=str,
|
||||||
|
help="The repository name on which the benchmarking is performed.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"branch",
|
"branch",
|
||||||
type=str,
|
type=str,
|
||||||
@ -108,7 +117,7 @@ def parse_arguments():
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return args.branch, args.commit_id, args.commit_msg
|
return args.repository, args.branch, args.commit_id, args.commit_msg
|
||||||
|
|
||||||
|
|
||||||
def import_from_path(module_name, file_path):
|
def import_from_path(module_name, file_path):
|
||||||
@ -125,7 +134,7 @@ def import_from_path(module_name, file_path):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
benchmarks_folder_path = os.path.dirname(os.path.realpath(__file__))
|
benchmarks_folder_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
branch, commit_id, commit_msg = parse_arguments()
|
repository, branch, commit_id, commit_msg = parse_arguments()
|
||||||
|
|
||||||
for entry in os.scandir(benchmarks_folder_path):
|
for entry in os.scandir(benchmarks_folder_path):
|
||||||
try:
|
try:
|
||||||
@ -136,7 +145,7 @@ if __name__ == "__main__":
|
|||||||
logger.debug(f"loading: {entry.name}")
|
logger.debug(f"loading: {entry.name}")
|
||||||
module = import_from_path(entry.name.split(".")[0], entry.path)
|
module = import_from_path(entry.name.split(".")[0], entry.path)
|
||||||
logger.info(f"running benchmarks in: {entry.name}")
|
logger.info(f"running benchmarks in: {entry.name}")
|
||||||
module.run_benchmark(logger, branch, commit_id, commit_msg)
|
module.run_benchmark(logger, repository, branch, commit_id, commit_msg)
|
||||||
except ImportModuleException as e:
|
except ImportModuleException as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
CREATE TABLE IF NOT EXISTS benchmarks (
|
CREATE TABLE IF NOT EXISTS benchmarks (
|
||||||
benchmark_id SERIAL PRIMARY KEY,
|
benchmark_id SERIAL PRIMARY KEY,
|
||||||
|
repository VARCHAR(255),
|
||||||
branch VARCHAR(255),
|
branch VARCHAR(255),
|
||||||
commit_id VARCHAR(72),
|
commit_id VARCHAR(72),
|
||||||
commit_message VARCHAR(70),
|
commit_message VARCHAR(70),
|
||||||
|
@ -33,11 +33,15 @@ def collect_metrics(benchmark_id, continue_metric_collection, metrics_recorder):
|
|||||||
sleep(0.01)
|
sleep(0.01)
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
|
def run_benchmark(
|
||||||
|
logger: Logger, repository: str, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100
|
||||||
|
):
|
||||||
continue_metric_collection = Event()
|
continue_metric_collection = Event()
|
||||||
metrics_thread = None
|
metrics_thread = None
|
||||||
model_id = "meta-llama/Llama-2-7b-hf"
|
model_id = "meta-llama/Llama-2-7b-hf"
|
||||||
metrics_recorder = MetricsRecorder(psycopg2.connect("dbname=metrics"), logger, branch, commit_id, commit_msg)
|
metrics_recorder = MetricsRecorder(
|
||||||
|
psycopg2.connect("dbname=metrics"), logger, repository, branch, commit_id, commit_msg
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
gpu_stats = gpustat.GPUStatCollection.new_query()
|
gpu_stats = gpustat.GPUStatCollection.new_query()
|
||||||
gpu_name = gpu_stats[0]["name"]
|
gpu_name = gpu_stats[0]["name"]
|
||||||
|
Loading…
Reference in New Issue
Block a user