Skip to the content.
HOME

SFT/RL by trl

How to use trl

最近仕事でtrlを使う機会があったので備忘録として以下に記す。

For Supervised Fine-tuning

SFTTrainerを使う。modelとtokenizerを読み込んだ後、以下のようにconfigとtrainerクラスを設定して学習させる。

from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    learning_rate=2e-4,
    gradient_checkpointing=True,
    num_train_epochs=10,
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    max_length=16384,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    output_dir=your_session_name,
    report_to="trackio",
    push_to_hub=True,
)

trainer = SFTTrainer(
    model=your_model,
    args=training_args,
    train_dataset=your_dataset,
    processing_class=your_tokenizer,
)

trainer.train()

データセットはconversationalの場合は、以下を一要素としたリストを用意する。

{"messages":[
            {"role": "system", "content": your_system_prompt},
            {"role": "user", "content": your_user_prompt},
            {"role": "assistant", "content": target_response}
    ]
}

その後、以下のコードでSFTTrainerが扱えるフォーマットに変換する。

from datasets import Dataset

def convert_to_text(example):
    messages = example["messages"]
    formatted = ""
    image_paths = []
    for m in messages:
        if m["role"] == "system":
            formatted += f"<|system|>\n{m['content']}\n"
        elif m["role"] == "user":
            formatted += f"<|user|>\n{m['content']}\n"
        elif m["role"] == "assistant":
            formatted += f"<|assistant|>\n{m['content']}\n"
    return {"text": formatted}

sft_train_dataset = Dataset.from_list(your_list_of_json)
sft_train_dataset = sft_train_dataset.map(convert_to_text)

For Reinforcement Learning

GRPOTrainerを例とする。SFTと同様にconfigとtrainerクラスを設定。 報酬計算メソッドは独自で設定可能であり、以下は生成テキスト (completions)と正解テキスト (ground_truth)のRougeLスコアを計算して、一定値以上の場合に報酬値を返すメソッドである。

import evaluate
from trl import GRPOConfig, GRPOTrainer

def reward_func(completions, ground_truth, **kwargs):
    rewards = []
    for completion, gt in zip(completions, ground_truth):
        completion = completion[0]['content']
        ROUGE_SCORE = evaluate.load("rouge")

        score = ROUGE_SCORE.compute(
            predictions=[completion],
            references=[gt],
            rouge_types=["rouge1"],
        )["rouge1"]

        if score > 0.4: reward = 1
        else: reward = 0

        rewards.append(reward)

    return rewards

training_args = GRPOConfig(
    output_dir=self.output_model_name,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    max_prompt_length=16384,
    max_completion_length=512,
    num_generations=2,
    num_train_epochs=10,
    bf16=True,
    remove_unused_columns=False,
    logging_steps=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    temperature=0.2
)

trainer = GRPOTrainer(
    model=your_model,
    reward_funcs=reward_func,
    args=training_args,
    train_dataset=your_train_dataset,
    processing_class=your_tokenizer
)
trainer.train()

configのper_device_train_batch_sizeには注意が必要で、こちらのissueにより現在はデバイス毎の生成数 (≠ バッチサイズ)になっている。 データセットの作り方はSFTと同様。今回の報酬関数のように正解テキストを用いる場合は、要素となるDictionaryに報酬関数の引数値と同じキー (例: ground_truth)を設定する。

Other modules

  • unsloth: trlと互換性があり、より学習パフォーマンスが向上する。
  • verl: trlとは別の強化学習によるファインチューニングに特化したモジュール。Multi-turn RLに対応している。