From c989ddd29450ea691098f0ce97d2465e52004f7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 13 Jun 2025 14:03:49 +0200 Subject: [PATCH] Simplify and update trl examples (#38772) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Simplify and update trl examples * Remove optim_args from SFTConfig in Trainer documentation * Update docs/source/en/trainer.md * Apply suggestions from code review * Update docs/source/en/trainer.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Quentin Gallouédec Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/ar/trainer.md | 165 ++++++++--------------------- docs/source/en/model_doc/mamba.md | 47 +++----- docs/source/en/model_doc/mamba2.md | 35 ++---- docs/source/en/trainer.md | 42 +++----- docs/source/ko/model_doc/mamba.md | 28 ++--- docs/source/ko/model_doc/mamba2.md | 35 ++---- docs/source/ko/trainer.md | 109 ++++--------------- 7 files changed, 114 insertions(+), 347 deletions(-) diff --git a/docs/source/ar/trainer.md b/docs/source/ar/trainer.md index 7bdebdca2a2..1784d76a4ec 100644 --- a/docs/source/ar/trainer.md +++ b/docs/source/ar/trainer.md @@ -306,75 +306,45 @@ pip install galore-torch ثم أضف ببساطة أحد `["galore_adamw"، "galore_adafactor"، "galore_adamw_8bit"]` في `optim` جنبًا إلى جنب مع `optim_target_modules`، والتي يمكن أن تكون قائمة من السلاسل أو التعبيرات النمطية regex أو المسار الكامل المطابق لأسماء الوحدات المستهدفة التي تريد تكييفها. فيما يلي مثال على النص البرمجي كامل(تأكد من `pip install trl datasets`): ```python -import torch import datasets -import trl - -from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') - -args = TrainingArguments( - output_dir="./test-galore"، +args = SFTConfig( + output_dir="./test-galore", max_steps=100, - per_device_train_batch_size=2, - optim="galore_adamw"، - optim_target_modules=[r".*.attn.*"، r".*.mlp.*"] + optim="galore_adamw", + optim_target_modules=[r".*.attn.*", r".*.mlp.*"], + gradient_checkpointing=True, ) - -model_id = "google/gemma-2b" - -config = AutoConfig.from_pretrained(model_id) - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_config(config).to(0) - -trainer = trl.SFTTrainer( - model=model, +trainer = SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=512, ) - trainer.train() ``` لتمرير معامﻻت إضافية يدعمها GaLore، يجب عليك تمرير `optim_args` بشكل صحيح، على سبيل المثال: ```python -import torch import datasets -import trl - -from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') - -args = TrainingArguments( +args = SFTConfig( output_dir="./test-galore", max_steps=100, - per_device_train_batch_size=2, optim="galore_adamw", optim_target_modules=[r".*.attn.*", r".*.mlp.*"], optim_args="rank=64, update_proj_gap=100, scale=0.10", + gradient_checkpointing=True, ) - -model_id = "google/gemma-2b" - -config = AutoConfig.from_pretrained(model_id) - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_config(config).to(0) - -trainer = trl.SFTTrainer( - model=model, +trainer = 
SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=512, ) - trainer.train() ``` يمكنك قراءة المزيد حول الطريقة في [المستودع الأصلي](https://github.com/jiaweizzhao/GaLore) أو [الورقة البحثية](https://huggingface.co/papers/2403.03507). @@ -386,37 +356,22 @@ trainer.train() يمكنك أيضًا إجراء تحسين طبقة تلو الأخرى عن طريق إضافة `layerwise` إلى اسم المُحسِّن كما هو موضح أدناه: ```python -import torch import datasets -import trl +from trl import SFTConfig, SFTTrainer -from transformers import TrainingArguments، AutoConfig، AutoTokenizer، AutoModelForCausalLM - -train_dataset = datasets.load_dataset('imdb'، split='train') - -args = TrainingArguments( - output_dir="./test-galore"، - max_steps=100، - per_device_train_batch_size=2، - optim="galore_adamw_layerwise"، - optim_target_modules=[r".*.attn.*"، r".*.mlp.*"] +train_dataset = datasets.load_dataset('imdb', split='train') +args = SFTConfig( + output_dir="./test-galore", + max_steps=100, + optim="galore_adamw_layerwise", + optim_target_modules=[r".*.attn.*", r".*.mlp.*"], + gradient_checkpointing=True, ) - -model_id = "google/gemma-2b" - -config = AutoConfig.from_pretrained(model_id) - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_config(config).to(0) - -trainer = trl.SFTTrainer( - model=model، - args=args، - train_dataset=train_dataset، - dataset_text_field='text'، - max_seq_length=512، +trainer = SFTTrainer( + model="google/gemma-2b", + args=args, + train_dataset=train_dataset, ) - trainer.train() ``` @@ -436,39 +391,21 @@ trainer.train() فيما يلي نص برمجي بسيط يوضح كيفية ضبط نموذج [google/gemma-2b](https://huggingface.co/google/gemma-2b) على مجموعة بيانات IMDB في الدقة الكاملة: ```python -import torch import datasets -from transformers import TrainingArguments، AutoTokenizer، AutoModelForCausalLM -import trl +from trl import SFTConfig, SFTTrainer -train_dataset = datasets.load_dataset('imdb'، split='train') - -args = TrainingArguments( - output_dir="./test-lomo"، - max_steps=100، - per_device_train_batch_size=4، - optim="adalomo"، - gradient_checkpointing=True، - logging_strategy="steps"، - logging_steps=1، - learning_rate=2e-6، - save_strategy="no"، - run_name="lomo-imdb"، +train_dataset = datasets.load_dataset('imdb', split='train') +args = SFTConfig( + output_dir="./test-lomo", + max_steps=100, + optim="adalomo", + gradient_checkpointing=True, ) - -model_id = "google/gemma-2b" - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained(model_id).to(0) - -trainer = trl.SFTTrainer( - model=model، - args=args، - train_dataset=train_dataset، - dataset_text_field='text'، - max_seq_length=1024، +trainer = SFTTrainer( + model="google/gemma-2b", + args=args, + train_dataset=train_dataset, ) - trainer.train() ``` @@ -524,39 +461,21 @@ trainer.train() فيما يلي نص برمجى بسيط لشرح كيفية ضبط [google/gemma-2b](https://huggingface.co/google/gemma-2b) بدقة على مجموعة بيانات IMDB بدقة كاملة: ```python -import torch import datasets -from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM -import trl +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') - -args = TrainingArguments( - output_dir="./test-schedulefree", - max_steps=1000, - per_device_train_batch_size=4, +args = SFTConfig( + output_dir="./test-galore", + max_steps=100, optim="schedule_free_adamw", gradient_checkpointing=True, - logging_strategy="steps", - logging_steps=1, - 
learning_rate=2e-6, - save_strategy="no", - run_name="sfo-imdb", ) - -model_id = "google/gemma-2b" - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained(model_id).to(0) - -trainer = trl.SFTTrainer( - model=model, +trainer = SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=1024, ) - trainer.train() ``` ## تسريع ومدرب diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 12a6a01974a..9ce98d8516a 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -97,39 +97,22 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) - Mamba stacks `mixer` layers which are equivalent to `Attention` layers. You can find the main logic of Mamba in the `MambaMixer` class. - The example below demonstrates how to fine-tune Mamba with [PEFT](https://huggingface.co/docs/peft). - ```py - from datasets import load_dataset - from trl import SFTTrainer - from peft import LoraConfig - from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments - - model_id = "state-spaces/mamba-130m-hf" - tokenizer = AutoTokenizer.from_pretrained(model_id) - model = AutoModelForCausalLM.from_pretrained(model_id) - dataset = load_dataset("Abirate/english_quotes", split="train") - training_args = TrainingArguments( - output_dir="./results", - num_train_epochs=3, - per_device_train_batch_size=4, - logging_dir='./logs', - logging_steps=10, - learning_rate=2e-3 - ) - lora_config = LoraConfig( - r=8, - target_modules=["x_proj", "embeddings", "in_proj", "out_proj"], - task_type="CAUSAL_LM", - bias="none" - ) - trainer = SFTTrainer( - model=model, - processing_class=tokenizer, + ```py + from datasets import load_dataset + from trl import SFTConfig, SFTTrainer + from peft import LoraConfig + + model_id = "state-spaces/mamba-130m-hf" + dataset = load_dataset("Abirate/english_quotes", split="train") + training_args = SFTConfig(dataset_text_field="quote") + lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"]) + trainer = SFTTrainer( + model=model_id, args=training_args, - peft_config=lora_config, - train_dataset=dataset, - dataset_text_field="quote", - ) - trainer.train() + train_dataset=dataset, + peft_config=lora_config, + ) + trainer.train() ``` ## MambaConfig diff --git a/docs/source/en/model_doc/mamba2.md b/docs/source/en/model_doc/mamba2.md index 5a577983a74..a2094070226 100644 --- a/docs/source/en/model_doc/mamba2.md +++ b/docs/source/en/model_doc/mamba2.md @@ -103,40 +103,19 @@ print(tokenizer.decode(output[0], skip_special_tokens=True)) - The example below demonstrates how to fine-tune Mamba 2 with [PEFT](https://huggingface.co/docs/peft). ```python -from trl import SFTTrainer +from datasets import load_dataset from peft import LoraConfig -from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments -model_id = 'mistralai/Mamba-Codestral-7B-v0.1' -tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False) -tokenizer.pad_token = tokenizer.eos_token -tokenizer.padding_side = "left" #enforce padding side left +from trl import SFTConfig, SFTTrainer -model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9') +model_id = "mistralai/Mamba-Codestral-7B-v0.1" dataset = load_dataset("Abirate/english_quotes", split="train") -# Without CUDA kernels, batch size of 2 occupies one 80GB device -# but precision can be reduced. 
-# Experiments and trials welcome! -training_args = TrainingArguments( - output_dir="./results", - num_train_epochs=3, - per_device_train_batch_size=2, - logging_dir='./logs', - logging_steps=10, - learning_rate=2e-3 -) -lora_config = LoraConfig( - r=8, - target_modules=["embeddings", "in_proj", "out_proj"], - task_type="CAUSAL_LM", - bias="none" -) +training_args = SFTConfig(dataset_text_field="quote", gradient_checkpointing=True, per_device_train_batch_size=4) +lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"]) trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, + model=model_id, args=training_args, - peft_config=lora_config, train_dataset=dataset, - dataset_text_field="quote", + peft_config=lora_config, ) trainer.train() ``` diff --git a/docs/source/en/trainer.md b/docs/source/en/trainer.md index 7174344487f..56f929884a5 100644 --- a/docs/source/en/trainer.md +++ b/docs/source/en/trainer.md @@ -392,15 +392,15 @@ training_args = TrainingArguments( [Gradient Low-Rank Projection (GaLore)](https://hf.co/papers/2403.03507) significantly reduces memory usage when training large language models (LLMs). One of GaLores key benefits is *full-parameter* learning, unlike low-rank adaptation methods like [LoRA](https://hf.co/papers/2106.09685), which produces better model performance. -Install the [GaLore](https://github.com/jiaweizzhao/GaLore) library, [TRL](https://hf.co/docs/trl/index), and [Datasets](https://hf.co/docs/datasets/index). +Install the [GaLore](https://github.com/jiaweizzhao/GaLore) and [TRL](https://hf.co/docs/trl/index) libraries. ```bash -pip install galore-torch trl datasets +pip install galore-torch trl ``` -Pick a GaLore optimizer (`"galore_adamw"`, `"galore_adafactor"`, `"galore_adamw_8bit`") and pass it to the `optim` parameter in [`TrainingArguments`]. Use the `optim_target_modules` parameter to specify which modules to adapt (can be a list of strings, regex, or a full path). +Pick a GaLore optimizer (`"galore_adamw"`, `"galore_adafactor"`, `"galore_adamw_8bit`") and pass it to the `optim` parameter in [`trl.SFTConfig`]. Use the `optim_target_modules` parameter to specify which modules to adapt (can be a list of strings, regex, or a full path). -Extra parameters supported by GaLore, `rank`, `update_proj_gap`, and `scale`, should be passed to the `optim_args` parameter in [`TrainingArguments`]. +Extra parameters supported by GaLore, `rank`, `update_proj_gap`, and `scale`, should be passed to the `optim_args` parameter in [`trl.SFTConfig`]. The example below enables GaLore with [`~trl.SFTTrainer`] that targets the `attn` and `mlp` layers with regex. 
@@ -411,29 +411,22 @@ The example below enables GaLore with [`~trl.SFTTrainer`] that targets the `attn ```py -import torch import datasets -import trl -from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') -args = TrainingArguments( +args = SFTConfig( output_dir="./test-galore", max_steps=100, - per_device_train_batch_size=2, optim="galore_adamw", optim_target_modules=[r".*.attn.*", r".*.mlp.*"], optim_args="rank=64, update_proj_gap=100, scale=0.10", + gradient_checkpointing=True, ) -config = AutoConfig.from_pretrained("google/gemma-2b") -tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") -model = AutoModelForCausalLM.from_config("google/gemma-2b").to(0) -trainer = trl.SFTTrainer( - model=model, +trainer = SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=512, ) trainer.train() ``` @@ -444,29 +437,22 @@ trainer.train() Append `layerwise` to the optimizer name to enable layerwise optimization. For example, `"galore_adamw"` becomes `"galore_adamw_layerwise"`. This feature is still experimental and does not support Distributed Data Parallel (DDP). The code below can only be run on a [single GPU](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory). Other features like gradient clipping and DeepSpeed may not be available out of the box. Feel free to open an [issue](https://github.com/huggingface/transformers/issues) if you encounter any problems! ```py -import torch import datasets -import trl -from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') -args = TrainingArguments( +args = SFTConfig( output_dir="./test-galore", max_steps=100, - per_device_train_batch_size=2, optim="galore_adamw_layerwise", optim_target_modules=[r".*.attn.*", r".*.mlp.*"], optim_args="rank=64, update_proj_gap=100, scale=0.10", + gradient_checkpointing=True, ) -config = AutoConfig.from_pretrained("google/gemma-2b") -tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b") -model = AutoModelForCausalLM.from_config("google/gemma-2b").to(0) -trainer = trl.SFTTrainer( - model=model, +trainer = SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=512, ) trainer.train() ``` diff --git a/docs/source/ko/model_doc/mamba.md b/docs/source/ko/model_doc/mamba.md index 4c1b898f2db..001ea609932 100644 --- a/docs/source/ko/model_doc/mamba.md +++ b/docs/source/ko/model_doc/mamba.md @@ -58,34 +58,18 @@ print(tokenizer.batch_decode(out)) ```python from datasets import load_dataset -from trl import SFTTrainer +from trl import SFTConfig, SFTTrainer from peft import LoraConfig -from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments + model_id = "state-spaces/mamba-130m-hf" -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained(model_id) dataset = load_dataset("Abirate/english_quotes", split="train") -training_args = TrainingArguments( - output_dir="./results", - num_train_epochs=3, - per_device_train_batch_size=4, - logging_dir='./logs', - logging_steps=10, - learning_rate=2e-3 -) -lora_config = LoraConfig( - r=8, - target_modules=["x_proj", "embeddings", "in_proj", "out_proj"], - 
task_type="CAUSAL_LM", - bias="none" -) +training_args = SFTConfig(dataset_text_field="quote") +lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"]) trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, + model=model_id, args=training_args, - peft_config=lora_config, train_dataset=dataset, - dataset_text_field="quote", + peft_config=lora_config, ) trainer.train() ``` diff --git a/docs/source/ko/model_doc/mamba2.md b/docs/source/ko/model_doc/mamba2.md index c6af73ed955..04ef4d070b8 100644 --- a/docs/source/ko/model_doc/mamba2.md +++ b/docs/source/ko/model_doc/mamba2.md @@ -57,40 +57,19 @@ print(tokenizer.batch_decode(out)) 이곳은 미세조정을 위한 초안 스크립트입니다: ```python -from trl import SFTTrainer +from datasets import load_dataset from peft import LoraConfig -from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments -model_id = 'mistralai/Mamba-Codestral-7B-v0.1' -tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False) -tokenizer.pad_token = tokenizer.eos_token -tokenizer.padding_side = "left" #왼쪽 패딩으로 설정 +from trl import SFTConfig, SFTTrainer -model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9') +model_id = "mistralai/Mamba-Codestral-7B-v0.1" dataset = load_dataset("Abirate/english_quotes", split="train") -# CUDA 커널없이는, 배치크기 2가 80GB 장치를 하나 차지합니다. -# 하지만 정확도는 감소합니다. -# 실험과 시도를 환영합니다! -training_args = TrainingArguments( - output_dir="./results", - num_train_epochs=3, - per_device_train_batch_size=2, - logging_dir='./logs', - logging_steps=10, - learning_rate=2e-3 -) -lora_config = LoraConfig( - r=8, - target_modules=["embeddings", "in_proj", "out_proj"], - task_type="CAUSAL_LM", - bias="none" -) +training_args = SFTConfig(dataset_text_field="quote", gradient_checkpointing=True, per_device_train_batch_size=4) +lora_config = LoraConfig(target_modules=["x_proj", "embeddings", "in_proj", "out_proj"]) trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, + model=model_id, args=training_args, - peft_config=lora_config, train_dataset=dataset, - dataset_text_field="quote", + peft_config=lora_config, ) trainer.train() ``` diff --git a/docs/source/ko/trainer.md b/docs/source/ko/trainer.md index 072e081a361..d753627c86f 100644 --- a/docs/source/ko/trainer.md +++ b/docs/source/ko/trainer.md @@ -267,75 +267,45 @@ pip install galore-torch 그런 다음 `optim`에 `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` 중 하나와 함께 `optim_target_modules`를 추가합니다. 이는 적용하려는 대상 모듈 이름에 해당하는 문자열, 정규 표현식 또는 전체 경로의 목록일 수 있습니다. 
아래는 end-to-end 예제 스크립트입니다(필요한 경우 `pip install trl datasets`를 실행): ```python -import torch import datasets -import trl - -from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') - -args = TrainingArguments( +args = SFTConfig( output_dir="./test-galore", max_steps=100, - per_device_train_batch_size=2, optim="galore_adamw", - optim_target_modules=["attn", "mlp"] + optim_target_modules=[r".*.attn.*", r".*.mlp.*"], + gradient_checkpointing=True, ) - -model_id = "google/gemma-2b" - -config = AutoConfig.from_pretrained(model_id) - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_config(config).to(0) - -trainer = trl.SFTTrainer( - model=model, +trainer = SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=512, ) - trainer.train() ``` GaLore가 지원하는 추가 매개변수를 전달하려면 `optim_args`를 설정합니다. 예를 들어: ```python -import torch import datasets -import trl - -from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') - -args = TrainingArguments( +args = SFTConfig( output_dir="./test-galore", max_steps=100, - per_device_train_batch_size=2, optim="galore_adamw", - optim_target_modules=["attn", "mlp"], + optim_target_modules=[r".*.attn.*", r".*.mlp.*"], optim_args="rank=64, update_proj_gap=100, scale=0.10", + gradient_checkpointing=True, ) - -model_id = "google/gemma-2b" - -config = AutoConfig.from_pretrained(model_id) - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_config(config).to(0) - -trainer = trl.SFTTrainer( - model=model, +trainer = SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=512, ) - trainer.train() ``` @@ -348,37 +318,22 @@ trainer.train() 다음과 같이 옵티마이저 이름에 `layerwise`를 추가하여 레이어별 최적화를 수행할 수도 있습니다: ```python -import torch import datasets -import trl - -from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') - -args = TrainingArguments( +args = SFTConfig( output_dir="./test-galore", max_steps=100, - per_device_train_batch_size=2, optim="galore_adamw_layerwise", - optim_target_modules=["attn", "mlp"] + optim_target_modules=[r".*.attn.*", r".*.mlp.*"], + gradient_checkpointing=True, ) - -model_id = "google/gemma-2b" - -config = AutoConfig.from_pretrained(model_id) - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_config(config).to(0) - -trainer = trl.SFTTrainer( - model=model, +trainer = SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=512, ) - trainer.train() ``` @@ -398,39 +353,21 @@ LOMO 옵티마이저는 [제한된 자원으로 대형 언어 모델의 전체 다음은 IMDB 데이터셋에서 [google/gemma-2b](https://huggingface.co/google/gemma-2b)를 최대 정밀도로 미세 조정하는 간단한 스크립트입니다: ```python -import torch import datasets -from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM -import trl +from trl import SFTConfig, SFTTrainer train_dataset = datasets.load_dataset('imdb', split='train') - -args = TrainingArguments( +args = SFTConfig( output_dir="./test-lomo", - max_steps=1000, - 
per_device_train_batch_size=4, + max_steps=100, optim="adalomo", gradient_checkpointing=True, - logging_strategy="steps", - logging_steps=1, - learning_rate=2e-6, - save_strategy="no", - run_name="lomo-imdb", ) - -model_id = "google/gemma-2b" - -tokenizer = AutoTokenizer.from_pretrained(model_id) -model = AutoModelForCausalLM.from_pretrained(model_id).to(0) - -trainer = trl.SFTTrainer( - model=model, +trainer = SFTTrainer( + model="google/gemma-2b", args=args, train_dataset=train_dataset, - dataset_text_field='text', - max_seq_length=1024, ) - trainer.train() ```
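
All of the hunks above converge on the same simplified TRL pattern: configuration moves from `transformers.TrainingArguments` into `trl.SFTConfig`, and `SFTTrainer` receives a model id string instead of a preloaded model and tokenizer, so the `AutoTokenizer`/`AutoModelForCausalLM` boilerplate disappears. For reference, a minimal consolidated sketch of that pattern, using the same `google/gemma-2b` model and IMDB dataset as the updated docs and assuming `trl`, `datasets`, and (for the GaLore optimizers) `galore-torch` are installed:

```python
# Minimal sketch of the simplified pattern used throughout the updated docs.
# Assumes: pip install trl datasets galore-torch   (galore-torch only for the "galore_*" optimizers)
import datasets
from trl import SFTConfig, SFTTrainer

train_dataset = datasets.load_dataset("imdb", split="train")

# SFTConfig replaces TrainingArguments and also carries the optimizer choice and its extras.
args = SFTConfig(
    output_dir="./test-galore",
    max_steps=100,
    optim="galore_adamw",                                   # or "adalomo", "schedule_free_adamw", ...
    optim_target_modules=[r".*.attn.*", r".*.mlp.*"],       # GaLore only: which modules to adapt
    optim_args="rank=64, update_proj_gap=100, scale=0.10",  # GaLore only: extra optimizer parameters
    gradient_checkpointing=True,
)

# Passing a model id string lets SFTTrainer load the model (and tokenizer) itself,
# which is why the explicit AutoTokenizer/AutoModelForCausalLM setup is removed above.
trainer = SFTTrainer(
    model="google/gemma-2b",
    args=args,
    train_dataset=train_dataset,
)
trainer.train()
```

The Mamba examples follow the same shape, additionally passing `peft_config=LoraConfig(...)` to `SFTTrainer` and setting `dataset_text_field="quote"` in `SFTConfig` to select the text column of the quotes dataset.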