# Distillation Trainer

## Overview
The Distillation Trainer implements on-policy knowledge distillation as described in *On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes* by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
The abstract from the paper is the following:

> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution.
The `DistillationTrainer` is designed for efficiently distilling teacher models of all sizes into smaller students. It extends the ideas from the `GKDTrainer` with three key optimizations:
- Generation buffer – decouples the training microbatch size from the generation batch size, letting vLLM batch many prompts in a single call across gradient accumulation steps. This alone can speed up training by up to 40x.
- Teacher server support – moves the teacher to an external vLLM server so it does not need to fit on the same GPUs as the student.
- Binary-encoded logprob payloads – packs log-probabilities into base64-encoded NumPy arrays instead of nested JSON lists, shrinking transfer payloads by ~5x.
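To see why the binary encoding helps, here is a standalone sketch of the two payload styles. The array shape and dtype are illustrative assumptions, not the trainer's actual wire format:

```python
import base64
import json

import numpy as np

# Hypothetical teacher payload: one float32 log-prob per generated token.
logprobs = np.random.default_rng(0).standard_normal(512).astype(np.float32)

# Nested-JSON-list encoding vs. base64-packed raw bytes.
json_payload = json.dumps(logprobs.tolist())
binary_payload = base64.b64encode(logprobs.tobytes()).decode("ascii")

# The binary payload round-trips losslessly and is several times smaller.
decoded = np.frombuffer(base64.b64decode(binary_payload), dtype=np.float32)
assert np.array_equal(decoded, logprobs)
print(f"json: {len(json_payload)} chars, base64: {len(binary_payload)} chars")
```

Base64 adds only ~33% overhead on top of the 4 bytes per float32, whereas a decimal JSON representation typically spends 15-20 characters per value, which is where the ~5x payload reduction comes from.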
The Distillation Trainer is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.
## Quick start
```python
from datasets import load_dataset

from trl.experimental.distillation import DistillationConfig, DistillationTrainer

# 1. Load dataset and format as prompt-only chat messages
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
    remove_columns=dataset.column_names,
)

# 2. Configure distillation
config = DistillationConfig(
    output_dir="results/distill-qwen-gsm8k",
    num_train_epochs=1,
    bf16=True,
    save_strategy="no",
    # Distillation
    lmbda=1.0,  # fully on-policy (student generates)
    beta=1.0,  # reverse KL
    # Teacher
    teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
)

# 3. Train
trainer = DistillationTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    args=config,
    train_dataset=dataset,
)
trainer.train()
trainer.save_model()
```

## Usage tips
The `trl.experimental.distillation.DistillationTrainer` has three key parameters, set via `trl.experimental.distillation.DistillationConfig`:
- `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, training is fully off-policy (dataset completions only). When `lmbda=1.0`, training is fully on-policy (the student generates all completions). For values in between, each gradient accumulation slice is randomly assigned as on- or off-policy based on `lmbda`.
- `beta`: controls the interpolation in the Generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while `beta=1.0` approximates reverse KL divergence. Values in between interpolate between the two.
- `loss_top_k`: number of top tokens to use for the KL/JSD loss. Set to `0` for exact full-vocabulary computation (local teacher only), or `> 0` for a top-k approximation. See more about top-k with an external teacher server below.
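As a reference for how `beta` interpolates between the two divergences, here is a toy NumPy sketch of one common generalized-JSD formulation on a single token distribution. It mirrors the description above but is not the trainer's fused implementation; the `teacher` and `student` arrays are illustrative:

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for dense discrete distributions."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

def generalized_jsd(teacher, student, beta):
    # beta=0.0 -> forward KL, beta=1.0 -> reverse KL, in between -> mixture JSD.
    if beta == 0.0:
        return kl(teacher, student)
    if beta == 1.0:
        return kl(student, teacher)
    mixture = beta * teacher + (1.0 - beta) * student
    return beta * kl(teacher, mixture) + (1.0 - beta) * kl(student, mixture)

teacher = np.array([0.7, 0.2, 0.1])
student = np.array([0.4, 0.4, 0.2])
for beta in (0.0, 0.5, 1.0):
    print(beta, generalized_jsd(teacher, student, beta))
```

At `beta=0.5` the mixture formulation reduces to the standard symmetric JSD, which is why swapping teacher and student leaves the value unchanged at that setting.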
### On-policy vs. off-policy
Setting lmbda=1.0 (fully on-policy) generally outperforms off-policy distillation because the student learns from its own mistakes rather than imitating trajectories it may never produce. The generation buffer ensures on-policy training stays efficient: prompts across gradient accumulation steps are batched into a single vLLM call.
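The buffering idea can be illustrated with a toy sketch. All names here are illustrative stand-ins, not trainer internals; `fake_generate` plays the role of a single batched vLLM call:

```python
# Toy illustration of decoupling generation from the training microbatch size.
def fake_generate(prompts):
    # Stand-in for one batched vLLM generation call over all prompts.
    return [p + " -> completion" for p in prompts]

micro_batch_size, grad_accum_steps = 8, 4
prompts = [f"prompt {i}" for i in range(micro_batch_size * grad_accum_steps)]

# Buffered: one generation call for the whole accumulation window...
buffer = fake_generate(prompts)

# ...then sliced back into per-step microbatches for the loss computation.
micro_batches = [
    buffer[i : i + micro_batch_size]
    for i in range(0, len(buffer), micro_batch_size)
]
assert len(micro_batches) == grad_accum_steps
```

The naive alternative would call `fake_generate` once per microbatch (four small calls here instead of one large one), which wastes the throughput of a batched inference engine.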
### Using an external teacher server
For teachers that do not fit on training GPUs (e.g., 100B+ parameters), host the teacher on a separate vLLM server and set use_teacher_server=True with teacher_model_server_url:
```python
config = DistillationConfig(
    output_dir="distilled-model",
    use_teacher_server=True,
    teacher_model_server_url="http://teacher-host:8000",
    loss_top_k=1,  # required with teacher server when beta > 0
    beta=1.0,
    lmbda=1.0,
)

trainer = DistillationTrainer(
    model="Qwen/Qwen3-4B",
    args=config,
    train_dataset=dataset,
)
trainer.train()
```

When using the teacher server:
- `loss_top_k` must be `> 0` when `beta=0.0` (forward KL)
- `loss_top_k` must be exactly `1` when `beta > 0` (reverse KL or JSD)
- `reverse_kl_top_1_mode="argmax"` is not supported
- the Liger kernel is not supported
## Expected dataset type
The dataset should be formatted as a conversational language modeling dataset:
```json
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}
```

When using fully on-policy distillation (`lmbda=1.0`), the assistant turn can be omitted, since the student will generate its own completions:

```json
{"messages": [{"role": "user", "content": "What color is the sky?"}]}
```

## DistillationTrainer
class trl.experimental.distillation.DistillationTrainer
< source >( model: transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module | str | None = None teacher_model: transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module | str = None args: trl.experimental.distillation.distillation_config.DistillationConfig | None = None data_collator: collections.abc.Callable[[list[typing.Any]], dict[str, typing.Any]] | None = None train_dataset: datasets.arrow_dataset.Dataset | None = None eval_dataset: datasets.arrow_dataset.Dataset | dict[str, datasets.arrow_dataset.Dataset] | None = None processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.image_processing_utils.BaseImageProcessor | transformers.feature_extraction_utils.FeatureExtractionMixin | transformers.processing_utils.ProcessorMixin | None = None compute_metrics: collections.abc.Callable[[transformers.trainer_utils.EvalPrediction], dict] | None = None callbacks: list[transformers.trainer_callback.TrainerCallback] | None = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: collections.abc.Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None peft_config: typing.Optional[ForwardRef('PeftConfig')] = None )
Trainer for knowledge distillation from a teacher model to a student model.
Supports:

- Generalized JSD loss (forward KL, reverse KL, or interpolated JSD via `beta`)
- On-policy / off-policy mixing via `lmbda` (buffered across gradient accumulation)
- Local teacher model or external teacher via vLLM server
- Student on-policy generation via vLLM or `model.generate()`
- Liger kernel for memory-efficient fused JSD loss
### train
< source >( resume_from_checkpoint: str | bool | None = None trial: optuna.Trial | dict[str, Any] | None = None ignore_keys_for_eval: list[str] | None = None ) → ~trainer_utils.TrainOutput
Parameters
- **resume_from_checkpoint** (`str` or `bool`, *optional*) — If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in `args.output_dir` as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.
- **trial** (`optuna.Trial` or `dict[str, Any]`, *optional*) — The trial run or the hyperparameter dictionary for hyperparameter search.
- **ignore_keys_for_eval** (`list[str]`, *optional*) — A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
Returns
~trainer_utils.TrainOutput
Object containing the global step count, training loss, and metrics.
Main training entry point.
Will save the model, so you can reload it using from_pretrained().
Will only save from the main process.
### push_to_hub
< source >( commit_message: str | None = 'End of training' blocking: bool = True token: str | None = None revision: str | None = None **kwargs )
Parameters
- **commit_message** (`str`, *optional*, defaults to `"End of training"`) — Message to commit while pushing.
- **blocking** (`bool`, *optional*, defaults to `True`) — Whether the function should return only when the `git push` has finished.
- **token** (`str`, *optional*, defaults to `None`) — Token with write permission to overwrite Trainer's original args.
- **revision** (`str`, *optional*) — The git revision to commit from. Defaults to the head of the "main" branch.
- **kwargs** (`dict[str, Any]`, *optional*) — Additional keyword arguments passed along to `~Trainer.create_model_card`.
Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.
## DistillationConfig
class trl.experimental.distillation.DistillationConfig
< source >( output_dir: str | None = None per_device_train_batch_size: int = 8 num_train_epochs: float = 3.0 max_steps: int = -1 learning_rate: float = 1e-06 lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear' lr_scheduler_kwargs: dict | str | None = None warmup_steps: float = 0 optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused' optim_args: str | None = None weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 optim_target_modules: None | str | list[str] = None gradient_accumulation_steps: int = 1 average_tokens_across_devices: bool = True max_grad_norm: float = 1.0 label_smoothing_factor: float = 0.0 bf16: bool | None = None fp16: bool = False bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: bool | None = None gradient_checkpointing: bool = True gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None torch_compile: bool = False torch_compile_backend: str | None = None torch_compile_mode: str | None = None use_liger_kernel: bool = False liger_kernel_config: dict[str, bool] | None = None use_cache: bool = False neftune_noise_alpha: float | None = None torch_empty_cache_steps: int | None = None auto_find_batch_size: bool = False logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps' logging_steps: float = 10 logging_first_step: bool = False log_on_each_node: bool = True logging_nan_inf_filter: bool = True include_num_input_tokens_seen: str | bool = 'no' log_level: str = 'passive' log_level_replica: str = 'warning' disable_tqdm: bool | None = None report_to: None | str | list[str] = 'none' run_name: str | None = None project: str = 'huggingface' trackio_space_id: str | None = 'trackio' eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no' eval_steps: float | None = None eval_delay: float = 0 per_device_eval_batch_size: int = 8 prediction_loss_only: bool = False eval_on_start: bool = False 
eval_do_concat_batches: bool = True eval_use_gather_object: bool = False eval_accumulation_steps: int | None = None include_for_metrics: list = <factory> batch_eval_metrics: bool = False save_only_model: bool = False save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps' save_steps: float = 500 save_on_each_node: bool = False save_total_limit: int | None = None enable_jit_checkpoint: bool = False push_to_hub: bool = False hub_token: str | None = None hub_private_repo: bool | None = None hub_model_id: str | None = None hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save' hub_always_push: bool = False hub_revision: str | None = None load_best_model_at_end: bool = False metric_for_best_model: str | None = None greater_is_better: bool | None = None ignore_data_skip: bool = False restore_callback_states_from_checkpoint: bool = False full_determinism: bool = False seed: int = 42 data_seed: int | None = None use_cpu: bool = False accelerator_config: dict | str | None = None parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None dataloader_drop_last: bool = False dataloader_num_workers: int = 0 dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False dataloader_prefetch_factor: int | None = None remove_unused_columns: bool = True label_names: list[str] | None = None train_sampling_strategy: str = 'random' length_column_name: str = 'length' ddp_find_unused_parameters: bool | None = None ddp_bucket_cap_mb: int | None = None ddp_broadcast_buffers: bool | None = None ddp_backend: str | None = None ddp_timeout: int = 1800 fsdp: list[transformers.trainer_utils.FSDPOption] | str | None = None fsdp_config: dict[str, typing.Any] | str | None = None deepspeed: dict | str | None = None debug: str | list[transformers.debug_utils.DebugOption] = '' skip_memory_metrics: bool = True do_train: bool = False do_eval: bool = False do_predict: bool = False resume_from_checkpoint: str | None = None 
warmup_ratio: float | None = None logging_dir: str | None = None local_rank: int = -1 model_init_kwargs: dict[str, typing.Any] | str | None = None max_length: int | None = 1024 temperature: float = 1.0 lmbda: float = 1.0 beta: float = 1.0 reverse_kl_top_1_mode: str = 'sampled' max_completion_length: int = 512 max_prompt_length: int | None = None disable_dropout: bool = True teacher_model_name_or_path: str | None = None teacher_model_revision: str | None = None teacher_model_init_kwargs: dict[str, typing.Any] | str | None = None use_teacher_server: bool = False teacher_model_server_url: str | None = None loss_top_k: int = 1 loss_add_tail: bool = True num_generations: int = 1 generation_batch_size: int | None = None top_p: float = 0.95 top_k: int = 0 use_vllm: bool = False vllm_mode: str = 'colocate' vllm_server_base_url: str | None = None vllm_server_host: str = '0.0.0.0' vllm_server_port: int = 8001 vllm_server_timeout: float = 240.0 vllm_group_port: int = 51216 vllm_gpu_memory_utilization: float = 0.3 vllm_tensor_parallel_size: int = 1 vllm_max_model_length: int | None = None vllm_model_impl: str = 'vllm' vllm_structured_outputs_regex: str | None = None vllm_sync_frequency: int = 1 vllm_enable_sleep_mode: bool = False wandb_entity: str | None = None wandb_project: str | None = None wandb_run_group: str | None = None log_completions: bool = False log_completions_steps: int = 100 num_completions_to_print: int | None = None )
Parameters that control the model
- **model_init_kwargs** (`dict[str, Any]`, *optional*) — Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the trainer is provided as a string.
- **max_length** (`int` or `None`, *optional*, defaults to `1024`) — Maximum total sequence length (prompt + completion) for tokenization and truncation.
Parameters that control the distillation
- **temperature** (`float`, *optional*, defaults to `1.0`) — Temperature for sampling during generation and for computing the distillation loss. Higher values produce softer probability distributions.
- **lmbda** (`float`, *optional*, defaults to `1.0`) — Probability of using on-policy (student-generated) data for each gradient accumulation slice. A value of `0.0` means fully off-policy (dataset completions only); `1.0` means fully on-policy.
- **beta** (`float`, *optional*, defaults to `1.0`) — Interpolation coefficient for the Generalized Jensen-Shannon Divergence loss. When `0.0`, the loss is the forward KL divergence. When `1.0`, the loss is the reverse KL divergence. When `0.5`, it is the standard JSD.
- **reverse_kl_top_1_mode** (`str`, *optional*, defaults to `"sampled"`) — Selection rule for the reverse-KL top-1 token when `beta > 0` and `loss_top_k == 1`. `"sampled"` uses the actual completion token in the batch; `"argmax"` uses the student's highest-probability token. This setting does not affect the forward-KL support, which always uses the teacher's top-1 token. Ignored when `beta == 0` or `loss_top_k != 1`.
- **max_completion_length** (`int`, *optional*, defaults to `512`) — Maximum number of tokens to generate per completion during on-policy generation.
- **disable_dropout** (`bool`, *optional*, defaults to `True`) — Whether to disable dropout in the student model during training.
Parameters that control the teacher model
- **teacher_model_name_or_path** (`str` or `None`, *optional*) — Model name or path for the teacher model. Used when the teacher is loaded locally.
- **teacher_model_revision** (`str` or `None`, *optional*) — Model revision of the teacher model (e.g., branch name, tag, or commit hash).
- **teacher_model_init_kwargs** (`dict[str, Any]` or `None`, *optional*) — Keyword arguments passed to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model from a string.
- **use_teacher_server** (`bool`, *optional*, defaults to `False`) — Whether to use an external vLLM teacher server instead of a local teacher model.
- **teacher_model_server_url** (`str` or `None`, *optional*) — Base URL of a vLLM server hosting the teacher model (e.g., `"http://localhost:8000"`). When `use_teacher_server=True`, teacher logprobs are fetched from this server instead of running a local forward pass.
- **loss_top_k** (`int`, *optional*, defaults to `1`) — Number of top tokens to use when computing the JSD/KL loss. Both student and teacher distributions are restricted to these K tokens and re-normalized before computing the divergence. If `0`, the full vocabulary is used. For local teachers, the general support rule is: teacher top-k for forward KL, student top-k for reverse KL, and the union for mixed JSD. When `beta > 0` and `loss_top_k == 1`, the forward support still uses the teacher's top-1 token, while the reverse top-1 token is controlled by `reverse_kl_top_1_mode`. When `use_teacher_server=True`, the pure forward path (`beta=0`) requires this to be positive and uses the teacher's top-k logprobs for the forward term. When `beta > 0`, server-backed distillation requires `loss_top_k == 1` and only supports `"sampled"` reverse top-1 tokens.
- **loss_add_tail** (`bool`, *optional*, defaults to `True`) — Whether to append a tail bucket that represents the remaining probability mass outside the selected top-k support when computing the loss.
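To make the top-k support and tail bucket concrete, here is a small NumPy sketch of the forward-KL case (teacher top-k support with a tail bucket, as with `loss_add_tail=True`). The helper name, vocabulary size, and probability values are illustrative, not trainer internals:

```python
import numpy as np

def teacher_topk_buckets(teacher, student, k):
    """Restrict both distributions to the teacher's top-k tokens and
    append a tail bucket holding the leftover probability mass."""
    idx = np.argsort(teacher)[::-1][:k]  # teacher top-k (forward-KL support)
    t = np.append(teacher[idx], 1.0 - teacher[idx].sum())
    s = np.append(student[idx], 1.0 - student[idx].sum())
    return t, s

# Toy 4-token vocabulary, restricted to the teacher's top 2 tokens.
teacher = np.array([0.5, 0.3, 0.15, 0.05])
student = np.array([0.25, 0.25, 0.25, 0.25])
t, s = teacher_topk_buckets(teacher, student, k=2)

# Both bucketed distributions still sum to 1, so a KL on them is well defined.
forward_kl = float(np.sum(t * (np.log(t) - np.log(s))))
print(t, s, forward_kl)
```

The tail bucket is what lets a top-k approximation remain a valid divergence between probability distributions rather than between unnormalized fragments.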
Parameters that control on-policy generation
- **num_generations** (`int`, *optional*, defaults to `1`) — Number of completions to generate per prompt during on-policy generation.
- **generation_batch_size** (`int` or `None`, *optional*) — Number of unique prompts per worker per optimizer step. If `None`, computed as `(per_device_train_batch_size * gradient_accumulation_steps) // num_generations`.
- **top_p** (`float`, *optional*, defaults to `0.95`) — Top-p (nucleus) sampling parameter for on-policy generation.
- **top_k** (`int`, *optional*, defaults to `0`) — Top-k sampling parameter for on-policy generation. `0` disables top-k filtering.
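For example, with assumed settings (these particular values are hypothetical), the default `generation_batch_size` works out as:

```python
# Hypothetical settings, used only to illustrate the default computation.
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
num_generations = 2

generation_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps
) // num_generations
print(generation_batch_size)  # 16 unique prompts per worker per optimizer step
```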
Parameters that control vLLM for student generation
- **use_vllm** (`bool`, *optional*, defaults to `False`) — Whether to use vLLM for generating on-policy completions from the student model.
- **vllm_mode** (`str`, *optional*, defaults to `"colocate"`) — Mode for student vLLM integration. Either `"server"` or `"colocate"`.
- **vllm_server_base_url** (`str` or `None`, *optional*) — Base URL for the student vLLM server. If provided, `vllm_server_host` and `vllm_server_port` are ignored.
- **vllm_server_host** (`str`, *optional*, defaults to `"0.0.0.0"`) — Host of the student vLLM server.
- **vllm_server_port** (`int`, *optional*, defaults to `8001`) — Port of the student vLLM server.
- **vllm_server_timeout** (`float`, *optional*, defaults to `240.0`) — Timeout for connecting to the student vLLM server.
- **vllm_group_port** (`int`, *optional*, defaults to `51216`) — Port for the vLLM weight-update group (NCCL communicator).
- **vllm_gpu_memory_utilization** (`float`, *optional*, defaults to `0.3`) — GPU memory utilization for the colocated student vLLM engine.
- **vllm_tensor_parallel_size** (`int`, *optional*, defaults to `1`) — Tensor parallel size for the colocated student vLLM engine.
- **vllm_max_model_length** (`int` or `None`, *optional*) — Maximum model sequence length for the colocated vLLM engine.
- **vllm_model_impl** (`str`, *optional*, defaults to `"vllm"`) — Model implementation backend for vLLM. Use `"vllm"` or `"transformers"`.
- **vllm_structured_outputs_regex** (`str` or `None`, *optional*) — Regex pattern for vLLM structured outputs.
- **vllm_sync_frequency** (`int`, *optional*, defaults to `1`) — Frequency (in training steps) at which student model weights are synchronized to the vLLM engine.
- **vllm_enable_sleep_mode** (`bool`, *optional*, defaults to `False`) — Enable vLLM sleep mode to offload student weights during the optimizer step.
Parameters that control logging
- **log_completions** (`bool`, *optional*, defaults to `False`) — Whether to log a sample of (prompt, completion) pairs every `log_completions_steps` steps. If `rich` is installed, the sample is printed. If `wandb` and/or `trackio` logging is enabled, the sample is also logged there.
- **log_completions_steps** (`int`, *optional*, defaults to `100`) — Number of steps between logging completions. Only used if `log_completions` is `True`.
- **num_completions_to_print** (`int` or `None`, *optional*) — Number of completions to print. If `None`, all completions are logged.
Configuration class for the `DistillationTrainer`.
Extends `TrainingArguments` with parameters specific to knowledge distillation. This config is independent of `SFTConfig`; all necessary fields are declared here.
Using `HfArgumentParser`, this class can be turned into argparse arguments that can be specified on the command line.