Code parrot minor fixes/niceties (#14666)

* Add some nicety flags to better control evaluation.

* Fix dependency issue with outdated requirement

* Add additional flag to example to ensure eval is done

* Wrap code into main function for accelerate launcher to find

* Fix valid batch size flag in readme

* Add note to install git-lfs when initializing/training the model

* Update examples/research_projects/codeparrot/scripts/arguments.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update examples/research_projects/codeparrot/README.md

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Revert "Wrap code into main function for accelerate launcher to find"

This reverts commit ff11df1c81.
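
The (since reverted) change above refers to the common pattern of moving a script's top-level code into a `main()` function so the accelerate launcher has an entry point to find. A minimal sketch of that pattern, with a placeholder body rather than the real codeparrot training code:

```python
# Minimal sketch of the wrapping pattern described above (placeholder body only).
def main():
    # build the accelerator, model, and dataloaders, then run the training loop here
    pass


if __name__ == "__main__":
    main()
```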

* Fix formatting issue

* Move git-lfs instructions to installation section

* Add a quick check before code generation for code evaluation

* Fix styling issue

* Update examples/research_projects/codeparrot/scripts/human_eval.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Make iterable dataset use passed in tokenizer rather than globally defined one

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: ncoop57 <nac33@students.uwf.edu>
Nathan Cooper 2021-12-13 03:30:50 -05:00 committed by GitHub
parent 91f3dfbfdd
commit 48bf7e47a0
5 changed files with 28 additions and 6 deletions

examples/research_projects/codeparrot/README.md

@@ -31,6 +31,8 @@ Before you run any of the scripts make sure you are logged in and can push to the hub
huggingface-cli login
```
Additionally, make sure you have git-lfs installed. You can find instructions for how to install it [here](https://git-lfs.github.com/).
## Dataset
The source of the dataset is the GitHub dump available on Google's [BigQuery](https://cloud.google.com/blog/topics/public-datasets/github-on-bigquery-analyze-all-the-open-source-code). The database was queried for all Python files with less than 1MB in size resulting in a 180GB dataset with over 20M files. The dataset is available on the Hugging Face Hub [here](https://huggingface.co/datasets/transformersbook/codeparrot).
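
As a quick illustration of using that dataset, a minimal sketch in streaming mode; the `content` field name is an assumption and not taken from this diff:

```python
# Sketch only: stream the ~180GB dataset instead of downloading it in full.
from datasets import load_dataset

ds = load_dataset("transformersbook/codeparrot", split="train", streaming=True)
example = next(iter(ds))
print(example["content"][:200])  # assumes the raw Python source lives in a "content" column
```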
@@ -96,7 +98,7 @@ If you want to train the small model you need to make some modifications:
accelerate launch scripts/codeparrot_training.py \
--model_ckpt lvwerra/codeparrot-small \
--train_batch_size 12 \
--eval_batch_size 12 \
--valid_batch_size 12 \
--learning_rate 5e-4 \
--num_warmup_steps 2000 \
--gradient_accumulation 1 \
@@ -125,7 +127,8 @@ python scripts/human_eval.py --model_ckpt lvwerra/codeparrot \
--do_sample True \
--temperature 0.2 \
--top_p 0.95 \
--n_samples=200
--n_samples=200 \
--HF_ALLOW_CODE_EVAL="0"
```
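
The new `--HF_ALLOW_CODE_EVAL` flag feeds the environment variable that the `code_eval` metric checks before it will execute generated code. A minimal sketch of that mechanism on its own, with illustrative values only:

```python
# Sketch: opt in to code execution, then score a toy prediction with pass@k.
import os

os.environ["HF_ALLOW_CODE_EVAL"] = "1"  # leaving it at "0" keeps execution disabled

from datasets import load_metric

code_eval = load_metric("code_eval")
pass_at_k, results = code_eval.compute(
    references=["assert add(2, 3) == 5"],
    predictions=[["def add(a, b):\n    return a + b"]],
    k=[1],
)
print(pass_at_k)  # e.g. {'pass@1': 1.0}
```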
The results as well as reference values are shown in the following table:

examples/research_projects/codeparrot/requirements.txt

@@ -4,4 +4,4 @@ accelerate==0.5.1
wandb==0.12.0
tensorboard==2.6.0
torch==1.9.0
huggingface-hub==0.0.19
huggingface-hub==0.1.0

examples/research_projects/codeparrot/scripts/arguments.py

@@ -83,6 +83,10 @@ class HumanEvalArguments:
metadata={"help": "Model name or path of model to be evaluated."},
)
num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
num_tasks: Optional[int] = field(
default=None,
metadata={"help": "The number of human-eval tasks to run. If not included all tasks are evaluated."},
)
do_sample: Optional[bool] = field(
default=True, metadata={"help": "Sample from the language model's output distribution."}
)
@@ -101,6 +105,12 @@ class HumanEvalArguments:
HF_ALLOW_CODE_EVAL: Optional[str] = field(
default="0", metadata={"help": "Allow `code_eval` to execute Python code on machine"}
)
device_int: Optional[int] = field(
default=-1,
metadata={
"help": "Determine which device to run the `text-generation` Pipeline on. -1 is CPU and any zero or positive number corresponds to which GPU device id to run on."
},
)
@dataclass
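
These dataclass fields become command-line flags through `HfArgumentParser`. A minimal sketch of how a script could consume them; the `from arguments import ...` line assumes it runs next to `scripts/arguments.py`:

```python
# Sketch: parse HumanEvalArguments fields such as num_tasks and device_int from the CLI.
from transformers import HfArgumentParser

from arguments import HumanEvalArguments  # assumes scripts/ is the working directory

parser = HfArgumentParser(HumanEvalArguments)
args = parser.parse_args_into_dataclasses()[0]

# e.g. `python human_eval.py --num_tasks 4 --device_int 0 --HF_ALLOW_CODE_EVAL "1"`
print(args.num_tasks, args.device_int, args.HF_ALLOW_CODE_EVAL)
```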

examples/research_projects/codeparrot/scripts/codeparrot_training.py

@@ -59,7 +59,7 @@ class ConstantLengthDataset(IterableDataset):
else:
more_examples = False
break
tokenized_inputs = tokenizer(buffer, truncation=False)["input_ids"]
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
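
For context, an abridged sketch of the class around this hunk; the constructor defaults and the `example["content"]` field are assumptions, not copied from the file:

```python
import torch
from torch.utils.data import IterableDataset


class ConstantLengthDataset(IterableDataset):
    """Abridged sketch: pack tokenized examples into fixed-length training sequences."""

    def __init__(self, tokenizer, dataset, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6):
        self.tokenizer = tokenizer  # kept on the instance so __iter__ never touches a global
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.input_characters = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while buffer_len < self.input_characters:
                try:
                    example = next(iterator)
                    buffer.append(example["content"])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    more_examples = False
                    break
            # the fix in this commit: use self.tokenizer, not a module-level tokenizer
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)
```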

examples/research_projects/codeparrot/scripts/human_eval.py

@@ -51,14 +51,23 @@ def main():
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)
# Load evaluation dataset and metric
human_eval = load_dataset("openai_humaneval")
code_eval_metric = load_metric("code_eval")
# Run a quick test to see if code evaluation is enabled
try:
_ = code_eval_metric.compute(references=[""], predictions=[[""]])
except ValueError as exception:
print(
'Code evaluation not enabled. Read the warning below carefully and then use the `--HF_ALLOW_CODE_EVAL="1"` flag to enable code evaluation.'
)
raise exception
# Generate completions for evaluation set
n_tasks = 4 # len(human_eval["test"])
n_tasks = args.num_tasks if args.num_tasks is not None else len(human_eval["test"])
generations, references = [], []
for task in tqdm(range(n_tasks)):
task_generations = []
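
The hunk stops mid-loop. A minimal sketch of one way the rest of the evaluation could proceed, restating the loop header for self-containment and assuming the variables defined above (`pipe`, `human_eval`, `code_eval_metric`, `n_tasks`, `generations`, `references`, `args`); `max_new_tokens` is an assumed value, not a flag shown in this diff:

```python
# Sketch only: finish generating per-task completions, then score them with pass@k.
gen_kwargs = {
    "do_sample": args.do_sample,
    "temperature": args.temperature,
    "top_p": args.top_p,
    "max_new_tokens": 256,  # assumed value, not taken from this diff
}

for task in tqdm(range(n_tasks)):
    prompt = human_eval["test"][task]["prompt"].strip()
    outputs = pipe(prompt, num_return_sequences=args.n_samples, **gen_kwargs)
    generations.append([out["generated_text"] for out in outputs])
    test_func = human_eval["test"][task]["test"]
    entry_point = human_eval["test"][task]["entry_point"]
    references.append(f"{test_func}\ncheck({entry_point})")

# Requires HF_ALLOW_CODE_EVAL="1"; the quick check above fails fast otherwise.
pass_at_k, _ = code_eval_metric.compute(references=references, predictions=generations, k=[1, 10, 100])
print(pass_at_k)
```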