transformers/examples/research_projects/rag/finetune_rag_ray.sh
Amog Kamsetty a4b21cdd20
[RAG] Add Ray implementation for distributed retrieval (#9197)
* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* uncomment

* uncomment

* wip

* updates

* add docstring

* updates

* fix arg

* fixes

* add unit tests

* update readme

* update readme

* update finetune script

* update test

* add test

* add ray to test dependencies

* separate ray and ray tune

* formatting

* shutdown ray at end of test

* fix tests

* formatting

* formatting

* even more formatting

* address comments

* formatting

* add files

* Update examples/research_projects/rag/test_distributed_retriever.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address comments

* addressing comments

Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2020-12-21 10:39:30 +01:00

45 lines
1.2 KiB
Bash
Executable File

# Sample script to finetune RAG using Ray for distributed retrieval.
#
# Required environment variables:
#   DATA_DIR            - passed to finetune_rag.py as --data_dir
#   OUTPUT_DIR          - passed to finetune_rag.py as --output_dir
#   MODEL_NAME_OR_PATH  - passed to finetune_rag.py as --model_name_or_path
#
# Exit status: that of the finetuning run (python), not of `ray stop`.

# Fail early with a clear message instead of letting an empty expansion
# shift the following flag into the previous option's value.
: "${DATA_DIR:?DATA_DIR must be set}"
: "${OUTPUT_DIR:?OUTPUT_DIR must be set}"
: "${MODEL_NAME_OR_PATH:?MODEL_NAME_OR_PATH must be set}"

# Add parent directory to python path to access lightning_base.py.
# The separator is only appended when PYTHONPATH already has a value, so an
# unset PYTHONPATH does not leave a trailing ':' (an empty entry, which Python
# treats as the current directory).
export PYTHONPATH="../${PYTHONPATH:+:${PYTHONPATH}}"

# Start a single-node Ray cluster.
if ! ray start --head; then
  echo "error: failed to start the Ray head node" >&2
  exit 1
fi

# Stop the Ray cluster on every exit path (normal completion, training
# failure, or interruption). The trap does not call `exit`, so the script
# still returns the status of the finetuning command below.
trap 'ray stop' EXIT

# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run python examples/rag/finetune_rag.py --help to see all the possible options
python examples/rag/finetune_rag.py \
  --data_dir "$DATA_DIR" \
  --output_dir "$OUTPUT_DIR" \
  --model_name_or_path "$MODEL_NAME_OR_PATH" \
  --model_type rag_sequence \
  --fp16 \
  --gpus 8 \
  --profile \
  --do_train \
  --do_predict \
  --n_val -1 \
  --train_batch_size 8 \
  --eval_batch_size 1 \
  --max_source_length 128 \
  --max_target_length 25 \
  --val_max_target_length 25 \
  --test_max_target_length 25 \
  --label_smoothing 0.1 \
  --dropout 0.1 \
  --attention_dropout 0.1 \
  --weight_decay 0.001 \
  --adam_epsilon 1e-08 \
  --max_grad_norm 0.1 \
  --lr_scheduler polynomial \
  --learning_rate 3e-05 \
  --num_train_epochs 100 \
  --warmup_steps 500 \
  --gradient_accumulation_steps 1 \
  --distributed_retriever ray \
  --num_retrieval_workers 4