Speeding Up Training
This guide covers various methods to accelerate training in TRL. Each technique includes minimal examples with links to more comprehensive documentation.
vLLM for fast generation in online methods
Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time. To speed up generation, you can use vLLM, a library that enables fast generation through, among other things, PagedAttention. TRL’s online trainers support vLLM, greatly improving training speed. For more details, see vLLM Integration.
To use vLLM, first install it using:
pip install trl[vllm]
Next, start a vLLM server by running:
trl vllm-serve --model <model_name>
Then, run the training script and pass use_vllm=True in the training arguments:
from trl.experimental.online_dpo import OnlineDPOConfig
training_args = OnlineDPOConfig(..., use_vllm=True)
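The same flag works with TRL's other online trainers, such as GRPO. A minimal sketch, assuming GRPOConfig accepts the same use_vllm argument:

from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)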
Optimized attention implementations

TRL supports various optimized attention implementations that can significantly speed up training while reducing memory usage. You can use either pre-optimized kernels directly from the Kernels Hub or a manually built attention backend.
You can use pre-optimized attention kernels from the Hub without manual compilation:
from trl import SFTConfig
training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"})

Other options include kernels-community/vllm-flash-attn3 and kernels-community/paged-attention.
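For example, to use the FlashAttention-3 kernel instead:

from trl import SFTConfig

training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "kernels-community/vllm-flash-attn3"})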
Optimized attention works across all TRL trainers. For more details, see Kernels Hub Integration.
Liger Kernel for memory optimization
Liger Kernel is a collection of Triton kernels designed for LLM training that can increase throughput by 20% and reduce memory usage by 60%.
from trl import SFTConfig
training_args = SFTConfig(..., use_liger_kernel=True)
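A complete minimal sketch, where the model and dataset names are placeholders to swap for your own (the flag requires the liger-kernel package to be installed):

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Placeholder dataset and model; replace with your own.
dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="liger-sft", use_liger_kernel=True)
trainer = SFTTrainer(model="Qwen/Qwen2.5-0.5B", args=training_args, train_dataset=dataset)
trainer.train()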
For more information, see Liger Kernel Integration.

Mixed precision training
Mixed precision training using bf16 or fp16 can speed up training and reduce memory usage with minimal impact on model quality.
from trl import SFTConfig
training_args = SFTConfig(..., bf16=True)  # or fp16=True for older GPUs

Use bf16=True for Ampere GPUs (A100, RTX 30xx) or newer, and fp16=True for older GPUs. Mixed precision training is supported across all TRL trainers.
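If a script needs to run on different hardware, you can pick the precision at runtime. A minimal sketch using PyTorch's built-in capability check (assumes a CUDA GPU is available):

import torch
from trl import SFTConfig

# Prefer bf16 on GPUs that support it (Ampere or newer); otherwise fall back to fp16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = SFTConfig(..., bf16=use_bf16, fp16=not use_bf16)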