SFT/RL by trl
How to use trl
最近仕事でtrlを使う機会があったので備忘録として以下に記す。
For Supervised Fine-tuning
SFTTrainerを使う。modelとtokenizerを読み込んだ後、以下のようにconfigとtrainerクラスを設定して学習させる。
from trl import SFTConfig, SFTTrainer
training_args = SFTConfig(
learning_rate=2e-4,
gradient_checkpointing=True,
num_train_epochs=10,
logging_steps=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
max_length=16384,
warmup_ratio=0.03,
lr_scheduler_type="cosine_with_min_lr",
lr_scheduler_kwargs={"min_lr_rate": 0.1},
output_dir=your_session_name,
report_to="trackio",
push_to_hub=True,
)
trainer = SFTTrainer(
model=your_model,
args=training_args,
train_dataset=your_dataset,
processing_class=your_tokenizer,
)
trainer.train()
データセットはconversationalの場合は、以下を一要素としたリストを用意する。
{"messages":[
{"role": "system", "content": your_system_prompt},
{"role": "user", "content": your_user_prompt},
{"role": "assistant", "content": target_response}
]
}
その後、以下のコードでSFTTrainerが扱えるフォーマットに変換する。
from datasets import Dataset
def convert_to_text(example):
messages = example["messages"]
formatted = ""
image_paths = []
for m in messages:
if m["role"] == "system":
formatted += f"<|system|>\n{m['content']}\n"
elif m["role"] == "user":
formatted += f"<|user|>\n{m['content']}\n"
elif m["role"] == "assistant":
formatted += f"<|assistant|>\n{m['content']}\n"
return {"text": formatted}
sft_train_dataset = Dataset.from_list(your_list_of_json)
sft_train_dataset = sft_train_dataset.map(convert_to_text)
For Reinforcement Learning
GRPOTrainerを例とする。SFTと同様にconfigとtrainerクラスを設定。
報酬計算メソッドは独自で設定可能であり、以下は生成テキスト (completions)と正解テキスト (ground_truth)のRougeLスコアを計算して、一定値以上の場合に報酬値を返すメソッドである。
import evaluate
from trl import GRPOConfig, GRPOTrainer
def reward_func(completions, ground_truth, **kwargs):
rewards = []
for completion, gt in zip(completions, ground_truth):
completion = completion[0]['content']
ROUGE_SCORE = evaluate.load("rouge")
score = ROUGE_SCORE.compute(
predictions=[completion],
references=[gt],
rouge_types=["rouge1"],
)["rouge1"]
if score > 0.4: reward = 1
else: reward = 0
rewards.append(reward)
return rewards
training_args = GRPOConfig(
output_dir=self.output_model_name,
learning_rate=2e-5,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
max_prompt_length=16384,
max_completion_length=512,
num_generations=2,
num_train_epochs=10,
bf16=True,
remove_unused_columns=False,
logging_steps=1,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
temperature=0.2
)
trainer = GRPOTrainer(
model=your_model,
reward_funcs=reward_func,
args=training_args,
train_dataset=your_train_dataset,
processing_class=your_tokenizer
)
trainer.train()
configのper_device_train_batch_sizeには注意が必要で、こちらのissueにより現在はデバイス毎の生成数 (≠ バッチサイズ)になっている。
データセットの作り方はSFTと同様。今回の報酬関数のように正解テキストを用いる場合は、要素となるDictionaryに報酬関数の引数値と同じキー (例: ground_truth)を設定する。