import json
import multiprocessing
import os
import re

from datasets import load_dataset, load_metric
from tqdm import tqdm

import transformers
from arguments import HumanEvalArguments
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    StoppingCriteria,
    StoppingCriteriaList,
    pipeline,
    set_seed,
)


EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"]


class EndOfFunctionCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""

    def __init__(self, start_length, eof_strings, tokenizer):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Returns True if all generated sequences contain any of the end-of-function strings."""
        decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
        done = []
        for decoded_generation in decoded_generations:
            done.append(any([stop_string in decoded_generation for stop_string in self.eof_strings]))
        return all(done)


def first_block(string):
    """Split off the first block of code by scanning for class, def etc. on newlines."""
    return re.split("|".join(EOF_STRINGS), string)[0].rstrip()


def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
    """Complete prompt with text generation pipeline and return num_completions."""
    prompt = pipe.tokenizer.eos_token + prompt
    code_gens = pipe(prompt, num_return_sequences=num_completions, **gen_kwargs)
    # strip the prompt from each generation and keep only the first code block
    return [first_block(code_gen["generated_text"][len(prompt) :]) for code_gen in code_gens]


def main():
    # Setup configuration
    parser = HfArgumentParser(HumanEvalArguments)
    args = parser.parse_args()
    transformers.logging.set_verbosity_error()

    # enables code execution in code_eval metric
    os.environ["HF_ALLOW_CODE_EVAL"] = args.HF_ALLOW_CODE_EVAL
    # make sure tokenizer plays nice with multiprocessing
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if args.num_workers is None:
        args.num_workers = multiprocessing.cpu_count()

    set_seed(args.seed)

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
    model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)

    # Generation settings
    gen_kwargs = {
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "stopping_criteria": StoppingCriteriaList([EndOfFunctionCriteria(0, EOF_STRINGS, tokenizer)]),
    }

    # Load evaluation dataset and metric
    human_eval = load_dataset("openai_humaneval")
    code_eval_metric = load_metric("code_eval")

    # Run a quick test to see if code evaluation is enabled
    try:
        _ = code_eval_metric.compute(references=[""], predictions=[[""]])
    except ValueError as exception:
        print(
            'Code evaluation not enabled. Read the warning below carefully and then use the `--HF_ALLOW_CODE_EVAL="1"` flag to enable code evaluation.'
        )
        raise exception

    # Generate completions for the evaluation set
    n_tasks = args.num_tasks if args.num_tasks is not None else len(human_eval["test"])
    generations, references = [], []
    for task in tqdm(range(n_tasks)):
        task_generations = []
        prompt = human_eval["test"][task]["prompt"].strip()
        # only check for end-of-function strings in the tokens generated after the prompt
        gen_kwargs["stopping_criteria"][0].start_length = len(tokenizer(prompt)["input_ids"])
        for batch in range(args.n_samples // args.batch_size):
            task_generations.extend(complete_code(pipe, prompt, num_completions=args.batch_size, **gen_kwargs))
        generations.append([prompt + gen for gen in task_generations])
        test_func = human_eval["test"][task]["test"]
        entry_point = f"check({human_eval['test'][task]['entry_point']})"
        references.append("\n" + test_func + "\n" + entry_point)

    # Evaluate completions with "code_eval" metric
    pass_at_k, _ = code_eval_metric.compute(
        references=references, predictions=generations, num_workers=args.num_workers
    )
    print(f"Results: {pass_at_k}")

    # Save results to json file
    with open(args.output_file, "w") as fp:
        json.dump(pass_at_k, fp)


# For some reason the following seems to be necessary sometimes for code_eval to work nicely with multiprocessing
# https://stackoverflow.com/questions/60804599/python-multiprocessing-keeps-spawning-the-whole-script
if __name__ == "__main__":
    main()
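

# Example invocation (a sketch only: the script filename and the flag values below are illustrative
# assumptions; the real flags and defaults come from `HumanEvalArguments`, which `HfArgumentParser`
# exposes as command-line arguments):
#
#   python human_eval.py \
#       --model_ckpt <model-name-or-path> \
#       --do_sample True --temperature 0.2 --top_p 0.95 \
#       --n_samples 200 --batch_size 10 \
#       --HF_ALLOW_CODE_EVAL "1" \
#       --output_file eval_results.json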