transformers/examples/research_projects/codeparrot/scripts/initialize_model.py
Loubna Ben Allal b48ac1a094
Fix CodeParrot training script (#17291)
* average loss over batches and accumulated steps for tracking

* fix layernorm weight decay

* use AdamW from Pytorch instead of Transformers

* add shuffling of sequences inside the batches

* add shuffling of sequences inside the batches

* add logging dir and reformat code

* fix lr tracking

* remove Mistral scaling

* keep Mistral scaling

* reformat code

* fix error

* fix error

* use shuffling function from Pytorch

* remove argument for shuffling batch sequences as it isn't optional

* update package versions and install accelerate from source

* remove unused package

* Update loss average over accumulated steps

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update loss average over accumulated steps

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* use one shuffle buffer argument

* compute avg_loss in one line

Co-authored-by: Loubna ben allal <loubnabenallal@gmail.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2022-05-23 12:55:35 +02:00


from arguments import InitializationArguments
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
# Configuration
parser = HfArgumentParser(InitializationArguments)
args = parser.parse_args()
# Load codeparrot tokenizer trained for Python code tokenization
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
# Config: "scale_attn_by_inverse_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
config_kwargs = {
    "vocab_size": len(tokenizer),
    "scale_attn_by_inverse_layer_idx": True,
    "reorder_and_upcast_attn": True,
}
# Load model config (GPT-2 large in this case)
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
# Initialize new model with config
model = AutoModelForCausalLM.from_config(config)
# Save model locally and, if args.push_to_hub is set, push it to the hub
model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
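
The "scale_attn_by_inverse_layer_idx" flag set above can be illustrated with a small standalone sketch. It assumes the GPT-2 behaviour of dividing attention scores by sqrt(head_dim) and then, when the flag is on, by (layer_idx + 1). The function name scaled_attention_weights and the toy inputs are hypothetical and are not part of the CodeParrot scripts or of the transformers API.

```python
import math


def scaled_attention_weights(scores, head_dim, layer_idx, scale_by_inverse_layer_idx=True):
    """Softmax over one row of raw attention scores (hypothetical helper).

    Assumption: scores are first divided by sqrt(head_dim), then, when the
    flag is on, additionally divided by (layer_idx + 1), with layer_idx
    counted from 0.
    """
    scale = math.sqrt(head_dim)
    if scale_by_inverse_layer_idx:
        scale *= float(layer_idx + 1)
    scaled = [s / scale for s in scores]
    # Subtract the maximum before exponentiating for numerical stability.
    max_score = max(scaled)
    exps = [math.exp(s - max_score) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Toy scores for a single query attending over two keys.
raw_scores = [4.0, 0.0]
first_layer = scaled_attention_weights(raw_scores, head_dim=16, layer_idx=0)
deeper_layer = scaled_attention_weights(raw_scores, head_dim=16, layer_idx=3)
```

With the same raw scores, the deeper layer ends up with a flatter attention distribution because its scores are shrunk more before the softmax. This is one way to see why the setting is described as a stability tweak.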