transformers/examples/research_projects/codeparrot/scripts/human_eval.py
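"""Evaluate a causal language model (e.g. a CodeParrot checkpoint) on the OpenAI HumanEval
benchmark: completions are generated with a text-generation pipeline and scored with the
`code_eval` metric, which executes the generated code to compute pass@k."""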

import json
import multiprocessing
import os
import re

from datasets import load_dataset, load_metric
from tqdm import tqdm

import transformers
from arguments import HumanEvalArguments
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, pipeline, set_seed


def first_block(string):
    """Split off first block of code by scanning for class, def etc. on newlines."""
    return re.split("\nclass|\ndef|\n#|\n@|\nprint|\nif", string)[0].rstrip()
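

# Illustration (not executed): for a completion such as "    return x + 1\n\ndef next_function():",
# `first_block` cuts at the first newline that starts a new class/def/comment/decorator/print/if
# statement and keeps only "    return x + 1".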


def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
    """Complete prompt with text generation pipeline and return num_completions."""
    prompt = pipe.tokenizer.eos_token + prompt
    code_gens = pipe(prompt, num_return_sequences=num_completions, **gen_kwargs)
    return [first_block(code_gen["generated_text"][len(prompt) :]) for code_gen in code_gens]
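

# Note: `complete_code` prepends the tokenizer's EOS token to the prompt (presumably because the
# model was trained on EOS-separated files) and strips the prompt, EOS prefix included, from each
# returned sequence before truncating it to its first code block with `first_block`.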


def main():
    # Setup configuration
    parser = HfArgumentParser(HumanEvalArguments)
    args = parser.parse_args()
    transformers.logging.set_verbosity_error()
    # enables code execution in code_eval metric
    os.environ["HF_ALLOW_CODE_EVAL"] = args.HF_ALLOW_CODE_EVAL
    # make sure tokenizer plays nice with multiprocessing
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if args.num_workers is None:
        args.num_workers = multiprocessing.cpu_count()

    set_seed(args.seed)

    # Generation settings
    gen_kwargs = {
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "top_p": args.top_p,
        "top_k": args.top_k,
    }
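    # Note: these generation kwargs are forwarded unchanged to the pipeline call inside
    # `complete_code`; their values come from the command-line flags defined in
    # `arguments.HumanEvalArguments`.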

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
    model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)

    # Load evaluation dataset and metric
    human_eval = load_dataset("openai_humaneval")
    code_eval_metric = load_metric("code_eval")

    # Run a quick test to see if code evaluation is enabled
    try:
        _ = code_eval_metric.compute(references=[""], predictions=[[""]])
    except ValueError as exception:
        print(
            'Code evaluation is not enabled. Read the warning below carefully and then use the `--HF_ALLOW_CODE_EVAL="1"` flag to enable code evaluation.'
        )
        raise exception

    # Generate completions for evaluation set
    n_tasks = args.num_tasks if args.num_tasks is not None else len(human_eval["test"])
    generations, references = [], []
    for task in tqdm(range(n_tasks)):
        task_generations = []
        prompt = human_eval["test"][task]["prompt"].strip()
        for batch in range(args.n_samples // args.batch_size):
            task_generations.extend(complete_code(pipe, prompt, num_completions=args.batch_size, **gen_kwargs))
        generations.append([prompt + gen for gen in task_generations])
        test_func = human_eval["test"][task]["test"]
        entry_point = f"check({human_eval['test'][task]['entry_point']})"
        references.append("\n" + test_func + "\n" + entry_point)
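    # Each reference is the HumanEval unit-test source followed by a call to `check(<entry_point>)`;
    # the `code_eval` metric runs every generated candidate against it to compute pass@k.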

    # Evaluate completions with "code_eval" metric
    pass_at_k, _ = code_eval_metric.compute(
        references=references, predictions=generations, num_workers=args.num_workers
    )
    print(f"Results: {pass_at_k}")

    # Save results to json file
    with open(args.output_file, "w") as fp:
        json.dump(pass_at_k, fp)


# For some reason the following seems to be necessary sometimes for code_eval to work nicely with multiprocessing
# https://stackoverflow.com/questions/60804599/python-multiprocessing-keeps-spawning-the-whole-script
if __name__ == "__main__":
    main()
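
# Example invocation (an illustrative sketch; the exact flag names and defaults are defined in
# arguments.HumanEvalArguments, and <model-name-or-path> is a placeholder):
#   python human_eval.py --model_ckpt <model-name-or-path> --HF_ALLOW_CODE_EVAL="1"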