Implementing Knowledge Distillation with a YOLO11n Student and YOLO11m Teacher in the Ultralytics Trainer

Hi Ultralytics community,

I am working on a lightweight model based on YOLO11n. I didn’t change the detect layer, but as expected, accuracy dropped compared to the original model. To recover the lost performance, I plan to apply knowledge distillation (KD) using YOLO11m as the teacher.

I created a custom trainer by inheriting from DetectionTrainer:

from ultralytics.models.yolo.detect import DetectionTrainer

class KDDetectionTrainer(DetectionTrainer):
    def __init__(self, cfg=None, overrides=None, distiller='mgd', loss_weight=1.0, _callbacks=None, teacher=None, student=None):
        super().__init__(cfg=cfg, overrides=overrides, _callbacks=_callbacks)
        self.model = student
        self.teacher = teacher
        self.teacher.model.eval()
        for param in self.teacher.model.parameters():
            param.requires_grad = False

And I instantiated it like this:

from ultralytics import YOLO
from ultralytics.models.yolo.detect.KDDetectionTrainer import KDDetectionTrainer
from ultralytics.utils import DEFAULT_CFG

student_model = YOLO("yolo11n.yaml")
teacher_model = YOLO(r"C:\Users\Hunger\Desktop\ultralytics\Custom_Distiller\best.pt")

args = dict(data="coco8.yaml", epochs=3)
trainer = KDDetectionTrainer(cfg=DEFAULT_CFG, student=student_model, teacher=teacher_model, overrides=args)

trainer.train()

My main question is: how can I access the outputs of the student and teacher during training so I can compute a custom distillation loss? I want to combine the original YOLO loss with the KD loss, but I’m not sure what is the proper way to hook into the training loop of DetectionTrainer to get the intermediate outputs.

Also, if anyone has similar implementations of knowledge distillation with Ultralytics YOLO models, it would be really helpful if you could share them for reference.

Any guidance or example on extending DetectionTrainer for knowledge distillation would be greatly appreciated!

Thanks in advance!


You can override the loss method:
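For instance, here is a stand-in sketch of the pattern (hypothetical minimal classes, not the real Ultralytics ones; see the BaseModel.loss() reference for the actual signature): call the parent's loss, then add a weighted KD term.

```python
# Stand-in sketch of the override pattern; classes and compute_kd() are hypothetical.
class DetectionModel:  # stands in for ultralytics.nn.tasks.DetectionModel
    def loss(self, batch, preds=None):
        return 1.0, [0.5, 0.3, 0.2]  # (total_loss, loss_items)

class KDDetectionModel(DetectionModel):
    kd_weight = 0.5

    def compute_kd(self, batch):  # hypothetical helper computing the KD term
        return 0.2

    def loss(self, batch, preds=None):
        # keep the original detection loss, then add the weighted KD term
        base_loss, loss_items = super().loss(batch, preds)
        return base_loss + self.kd_weight * self.compute_kd(batch), loss_items

print(KDDetectionModel().loss({"img": None}))  # (1.1, [0.5, 0.3, 0.2])
```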


Thank you, is it that easy?

@Toxite, is this OK?

import torch
import torch.nn.functional as F

from ultralytics.utils.loss import v8DetectionLoss


class KDDetectionLoss(v8DetectionLoss):

    def __init__(
        self,
        model,
        teacher,
        tal_topk=10,
        kd_weight=1.0,
    ):
        super().__init__(model, tal_topk)

        self.model = model
        self.teacher = teacher
        self.kd_weight = kd_weight

        # find the Detect head
        self.student_detect = self._get_detect(model.model)
        self.teacher_detect = self._get_detect(teacher.model)

        # have Detect cache P3/P4/P5
        self.student_detect.kd_collect = True
        self.teacher_detect.kd_collect = True

        # freeze the teacher
        self.teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad = False

    @staticmethod
    def _get_detect(model):
        for m in model.modules():
            if getattr(m, "is_detect", False):
                return m
        raise RuntimeError("Detect module not found")

    def __call__(self, batch, preds):
        """
        batch: dict
        preds: student forward output
        """

        # 1) original YOLO detection loss
        det_loss, loss_items = super().__call__(preds, batch)

        # 2) forward the teacher (must be done manually)
        with torch.no_grad():
            _ = self.teacher(batch["img"])

        # 3) grab P3/P4/P5
        s_feats = getattr(self.student_detect, "_kd_feat", None)
        t_feats = getattr(self.teacher_detect, "_kd_feat", None)

        if s_feats is None or t_feats is None:
            return det_loss, loss_items

        # 4) feature KD loss
        kd_loss = 0.0
        for fs, ft in zip(s_feats, t_feats):
            kd_loss += F.mse_loss(fs, ft.detach())

        total_loss = det_loss + self.kd_weight * kd_loss

        return total_loss, loss_items

And my patched Detect.forward:

def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor] | tuple:
    """Concatenate and return predicted bounding boxes and class probabilities."""
    if getattr(self, "kd_collect", False):
        self._kd_feat = x  # x = [P3, P4, P5]
    if self.end2end:
        return self.forward_end2end(x)

    for i in range(self.nl):
        x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
    if self.training:  # Training path
        return x
    y = self._inference(x)
    return y if self.export else (y, x)

It’s a good direction (feature KD on the Detect inputs), but I wouldn’t merge it as-is.

Right now you’re patching Detect.forward() and caching x by reference (self._kd_feat = x). That list then gets modified in-place (x[i] = torch.cat(...)), so what you later read may not be the “pre-head P3/P4/P5” features you intended. Also, YOLO11n vs YOLO11m feature channel sizes don’t match, so a straight F.mse_loss(fs, ft) will usually error unless you add a small adapter/projection.
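To make the by-reference issue concrete with plain Python (strings standing in for the feature tensors):

```python
feats = ["P3", "P4", "P5"]         # stands in for the list x passed into Detect.forward
cached_by_ref = feats              # what `self._kd_feat = x` does
cached_copy = list(feats)          # a shallow copy keeps the original entries

for i in range(len(feats)):
    feats[i] = f"cat({feats[i]})"  # stands in for `x[i] = torch.cat(...)`

print(cached_by_ref)  # ['cat(P3)', 'cat(P4)', 'cat(P5)'] -- mutated under you
print(cached_copy)    # ['P3', 'P4', 'P5'] -- the pre-head features survive
```

So even a shallow `list(x)` at cache time avoids reading the post-concat tensors later.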

A cleaner pattern in Ultralytics is to override the model’s loss() (as suggested) and capture features with a forward pre-hook on the Detect module, so you don’t have to edit core modules. See how BaseModel.loss() is structured in the docs for where this plugs in: the BaseModel loss() reference and the Advanced Customization guide are the two relevant entry points.

Minimal sketch (key ideas only): use the underlying PyTorch modules (YOLO(...).model), use a pre-hook to grab Detect inputs, and add an adapter if channels differ.

import torch
import torch.nn.functional as F

from ultralytics import YOLO

# teacher/student should be torch.nn.Module models, i.e. YOLO(...).model
student = YOLO("yolo11n.yaml").model
teacher = YOLO("yolo11m.pt").model  # or your best.pt
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

s_feats, t_feats = None, None

def save_student_feats(m, inp):  # inp[0] is the list [P3,P4,P5] passed into Detect
    global s_feats  # module level, so `global` rather than `nonlocal`
    s_feats = inp[0]

def save_teacher_feats(m, inp):
    global t_feats
    t_feats = inp[0]

student.model[-1].register_forward_pre_hook(save_student_feats)
teacher.model[-1].register_forward_pre_hook(save_teacher_feats)

# inside your overridden loss():
with torch.no_grad():
    _ = teacher(batch["img"])  # make sure this runs under the same AMP context as training if AMP is on

# if channels differ, you need adapters here before mse
kd_loss = sum(F.mse_loss(fs, ft.detach()) for fs, ft in zip(s_feats, t_feats))
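On the channel-mismatch point, a minimal adapter sketch: 1x1 convs projecting student features to the teacher's width. The channel counts and map sizes below are placeholders, not the real YOLO11n/YOLO11m widths. Since the adapters are trainable, their parameters must also be registered with the optimizer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

s_channels = [64, 128, 256]   # placeholder student P3/P4/P5 widths (assumption)
t_channels = [128, 256, 512]  # placeholder teacher P3/P4/P5 widths (assumption)

# one 1x1 conv per scale, mapping student channels -> teacher channels
adapters = nn.ModuleList(
    nn.Conv2d(cs, ct, kernel_size=1, bias=False) for cs, ct in zip(s_channels, t_channels)
)

# dummy feature maps standing in for the hooked P3/P4/P5 tensors
s_feats = [torch.randn(2, c, s, s) for c, s in zip(s_channels, (80, 40, 20))]
t_feats = [torch.randn(2, c, s, s) for c, s in zip(t_channels, (80, 40, 20))]

# MSE after projection; detach the teacher so no grads flow into it
kd_loss = sum(
    F.mse_loss(adapter(fs), ft.detach())
    for adapter, fs, ft in zip(adapters, s_feats, t_feats)
)
print(kd_loss)  # scalar tensor with grad_fn, backprops into student + adapters
```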

If you want, paste how you’re wiring this into DetectionTrainer (i.e., where you set model.criterion / override DetectionModel.loss()), and I can point to the most “Ultralytics-native” place to attach the teacher + hooks without fighting the trainer internals.


Thanks for the detailed review and concrete suggestions — this is very helpful.

You’re absolutely right about caching the Detect inputs by reference and the in-place modification issue, as well as the channel mismatch between YOLO11n and YOLO11m. Using a forward pre-hook and attaching KD logic via an overridden loss() is much cleaner and aligns better with Ultralytics’ design.

Here’s my code:

from ultralytics.models.yolo.detect import DetectionTrainer


class KDDetectionTrainer(DetectionTrainer):
    def __init__(self, cfg=None, overrides=None, distiller='mgd', loss_weight=1.0, _callbacks=None, teacher=None, student=None):
        super().__init__(cfg=cfg, overrides=overrides, _callbacks=_callbacks)
        # load the teacher model
        # self.teacher = YOLO('yolo11x.pt')  # assuming you have a pretrained YOLO11x model
        self.model = student.model
        self.teacher = teacher.model
        self.teacher.eval()  # put the teacher model in eval mode
        for param in self.teacher.parameters():
            param.requires_grad = False  # the teacher needs no gradients
        self.model.loss = KDDetectionLoss(
            model=self.model,
            teacher=self.teacher,
            kd_weight=loss_weight
        )
class KDDetectionLoss(v8DetectionLoss):

    def __init__(
        self,
        model,
        teacher,
        tal_topk=10,
        kd_weight=1.0,
    ):
        super().__init__(model, tal_topk)

        self.teacher = teacher
        self.kd_weight = kd_weight

        # Detect head
        self.student_detect = model.model[-1]
        self.teacher_detect = teacher.model[-1]

        # enable feature caching
        self.student_detect.kd_collect = True
        self.teacher_detect.kd_collect = True

        # freeze the teacher
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.teacher.eval()

    # NOTE: the signature must match v8DetectionLoss exactly
    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """
        preds: student forward output
        batch: dict
        """

        # 1) original YOLO loss (unchanged)
        det_loss, loss_items = super().__call__(preds, batch)

        # 2) forward the teacher (the correct way)
        with torch.no_grad():
            _ = self.teacher(batch["img"])

        # 3) grab P3/P4/P5
        s_feats = getattr(self.student_detect, "_kd_feat", None)
        t_feats = getattr(self.teacher_detect, "_kd_feat", None)

        if s_feats is None or t_feats is None:
            return det_loss, loss_items

        # 4) feature KD
        kd_loss = 0.0
        for fs, ft in zip(s_feats, t_feats):
            kd_loss += F.mse_loss(fs, ft)

        total_loss = det_loss + self.kd_weight * kd_loss

        return total_loss, loss_items

It still doesn’t work. Maybe I should give up on this approach…

student_model = YOLO("yolo11n.yaml")
teacher_model = YOLO(r"C:\Users\AAAAA\Desktop\ultralytics\Custom_Distiller\best.pt")

args = dict(model=student_model, data="coco8.yaml", epochs=3)
trainer = KDDetectionTrainer(cfg=DEFAULT_CFG, student=student_model, teacher=teacher_model, overrides=args)

trainer.train()

optimizer: AdamW(lr=0.000119, momentum=0.9) with parameter groups 81 weight(decay=0.0), 88 weight(decay=0.0005), 87 bias(decay=0.0)
Image sizes 640 train, 640 val
Using 0 dataloader workers
Logging results to C:\Users\AAAAA\Desktop\ultralytics\runs\detect\train26
Starting training for 3 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
Backend tkagg is interactive backend. Turning interactive mode on.
: 0% ──────────── 0/1  36.9s
Traceback (most recent call last):
  File "C:\Users\AAAAA\Desktop\ultralytics\ultralytics\engine\trainer.py", line 422, in _do_train
    loss, self.loss_items = self.model(batch)
                            ^^^^^^^^^^^^^^^^^
  File "D:\ultralytics\Lib\site-packages\torch\nn\modules\module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\ultralytics\Lib\site-packages\torch\nn\modules\module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\AAAAA\Desktop\ultralytics\ultralytics\nn\tasks.py", line 162, in forward
    return self.loss(x, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: KDDetectionLoss.__call__() missing 1 required positional argument: 'batch'
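The traceback is consistent with how tasks.py invokes the assigned callable: `self.loss(x, ...)` passes the batch dict as the first positional argument, so a `__call__` expecting `(preds, batch)` receives only one argument. A toy reproduction of the mismatch (function names are illustrative):

```python
def kd_call(preds, batch):  # stands in for KDDetectionLoss.__call__, expecting two positionals
    return preds, batch

batch = {"img": "a tensor"}
try:
    kd_call(batch)  # forward() passes only the batch dict, as in the traceback
except TypeError as e:
    print(e)  # kd_call() missing 1 required positional argument: 'batch'
```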

Sorry to bother you again :sob: @pderrenger, but I don’t know how to override the loss method; its inheritance chain is too complicated.

1. I tried writing the code like this:

student.model.loss = kd_loss.__get__(student.model, type(student.model))

but during training it still executes the original loss().
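For what it's worth, instance-level binding via `__get__` does work in plain Python, so if the original `loss()` still runs, the trainer is almost certainly training a different model object than the one you patched (it rebuilds the model in `get_model()`). A quick standalone check (toy classes, no Ultralytics involved):

```python
class Model:
    def loss(self):
        return "original"

def kd_loss(self):
    return "kd"

m = Model()
m.loss = kd_loss.__get__(m, Model)  # bind the new function to this instance only

print(m.loss())        # kd -- the instance-level patch does take effect
print(Model().loss())  # original -- other instances are untouched
```

So the fix is to make the trainer actually use your patched model, e.g. by overriding DetectionTrainer.get_model() to return it.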

2. I also tried inheriting from DetectionModel, like this:

class DetectionModelWithKD(DetectionModel):

    def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True, teacher=None, kd_weight=1):
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)

        self.kd_weight = kd_weight

        self.teacher = teacher
        self.teacher.model.eval()
        for p in self.teacher.model.parameters():
            p.requires_grad = False

        self.s_feats, self.t_feats = None, None

        self.student_feats_hook = self.model[-1].register_forward_pre_hook(self.save_student_feats)
        self.teacher_feats_hook = self.teacher.model.model[-1].register_forward_pre_hook(self.save_teacher_feats)

        # self.original_loss = self.loss

    def save_student_feats(self, m, inp):
        self.s_feats = inp[0]  # [P3, P4, P5]

    def save_teacher_feats(self, m, inp):
        self.t_feats = inp[0]

    def loss(self, batch, preds=None):
        """
        Compute loss with knowledge distillation (KD).

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
        """
        if getattr(self, "criterion", None) is None:
            self.criterion = self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"])
        total_loss, loss_items = self.original_loss(batch, preds)

        with torch.no_grad():
            _ = self.teacher.model(batch["img"])  # batch["img"] is already on the right device

        kd_term = 0.0
        for fs, ft in zip(self.s_feats, self.t_feats):
            kd_term += F.mse_loss(fs, ft.detach())

        total_loss = total_loss + self.kd_weight * kd_term

        loss_items = torch.cat([loss_items, kd_term.detach().unsqueeze(0)])

        self.s_feats, self.t_feats = None, None
        return total_loss, loss_items
    def original_loss(self, batch, preds=None):
        """
        Compute loss.

        Args:
            batch (dict): Batch to compute loss on.
            preds (torch.Tensor | list[torch.Tensor], optional): Predictions.
        """
        if getattr(self, "criterion", None) is None:
            self.criterion = self.init_criterion()

        if preds is None:
            preds = self.forward(batch["img"])
        return self.criterion(preds, batch)

I also modified train.py, but it still executes the original loss(). It first takes the `if self.Distill:` branch, but later takes the `else:` branch, so the model still ends up as a plain DetectionModel:

    def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
        """
        Return a YOLO detection model.

        Args:
            cfg (str, optional): Path to model configuration file.
            weights (str, optional): Path to model weights.
            verbose (bool): Whether to display model information.

        Returns:
            (DetectionModel): YOLO detection model.
        """
        if self.Distill:
            model = DetectionModelWithKD(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1,teacher=self.Student)
        else:
            model = DetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)

        if weights:
            model.load(weights)
        return model
class BaseTrainer:

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        self.hub_session = overrides.pop("session", None)  # HUB
        self.Distill = overrides.pop("Distill", False)
        self.Student = overrides["model"]

I’m really interested in building a minimal training/inference pipeline for a deep understanding of the YOLO models. I think it would be really helpful for your use case as well, since model distillation involves forward passes through two models before each backward pass.

Moreover, I read the CrossKD paper, which uses a novel technique for distilling object detection models. From it, I understand that what you want to do is feature mimicking. Am I correct?

We can use the snippet shared by @pderrenger as a starting point, since it’s much more concise.

That’s exactly what I was thinking :heart_hands: :face_holding_back_tears:
