Implementing Knowledge Distillation with YOLO11N Student and YOLO11M Teacher in Ultralytics Trainer

Hi Ultralytics community,

I am working on a lightweight model based on YOLO11N. I didn’t change the detect layer, but as expected, the accuracy dropped compared to the original model. To recover the lost performance, I plan to apply knowledge distillation (KD) using YOLO11M as the teacher.

I created a custom trainer by inheriting from DetectionTrainer:

from ultralytics.models.yolo.detect import DetectionTrainer

class KDDetectionTrainer(DetectionTrainer):
    def __init__(self, cfg=None, overrides=None, distiller='mgd', loss_weight=1.0, _callbacks=None, teacher=None, student=None):
        super().__init__(cfg=cfg, overrides=overrides, _callbacks=_callbacks)
        self.model = student  # train the lightweight student model
        self.teacher = teacher  # keep the teacher only for distillation
        self.teacher.model.eval()  # teacher stays in eval mode and is never updated
        for param in self.teacher.model.parameters():
            param.requires_grad = False

And I instantiated it like this:

from ultralytics import YOLO
from ultralytics.models.yolo.detect.KDDetectionTrainer import KDDetectionTrainer
from ultralytics.utils import DEFAULT_CFG

student_model = YOLO("yolo11n.yaml")
teacher_model = YOLO(r"C:\Users\Hunger\Desktop\ultralytics\Custom_Distiller\best.pt")

args = dict(data="coco8.yaml", epochs=3)
trainer = KDDetectionTrainer(cfg=DEFAULT_CFG, student=student_model, teacher=teacher_model, overrides=args)

trainer.train()

My main question is: how can I access the outputs of the student and teacher during training so I can compute a custom distillation loss? I want to combine the original YOLO loss with the KD loss, but I'm not sure what the proper way is to hook into the training loop of DetectionTrainer to get the intermediate outputs.

Also, if anyone has similar implementations of knowledge distillation with Ultralytics YOLO models, it would be really helpful if you could share them for reference.

Any guidance or example on extending DetectionTrainer for knowledge distillation would be greatly appreciated!

Thanks in advance!

You can override the loss method:
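For example, here is a minimal sketch of the idea (KDLoss is a placeholder name and the KD term itself is left as a stub for you to fill in; v8DetectionLoss is the default criterion that DetectionModel.init_criterion() returns, and during training the model calls it as criterion(preds, batch)):

import torch
from ultralytics.utils.loss import v8DetectionLoss


class KDLoss(v8DetectionLoss):
    """Standard YOLO detection loss plus a knowledge-distillation term."""

    def __init__(self, model, teacher, kd_weight=1.0):
        super().__init__(model)
        self.teacher = teacher      # frozen teacher network (an nn.Module)
        self.kd_weight = kd_weight

    def __call__(self, preds, batch):
        # Regular YOLO loss on the student predictions
        det_loss, loss_items = super().__call__(preds, batch)

        # Run the teacher on the same images, without gradients
        with torch.no_grad():
            teacher_preds = self.teacher(batch["img"])

        # Placeholder KD term: replace with your own comparison of the student
        # predictions/features against teacher_preds (logit matching, feature MSE, MGD, ...)
        kd_loss = torch.zeros((), device=det_loss.device)

        return det_loss + self.kd_weight * kd_loss, loss_items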

Thank you, is it that easy?

@Toxite is this OK?

import torch
import torch.nn.functional as F

from ultralytics.nn.modules import Detect
from ultralytics.utils.loss import v8DetectionLoss


class KDDetectionLoss(v8DetectionLoss):

    def __init__(
        self,
        model,
        teacher,
        tal_topk=10,
        kd_weight=1.0,
    ):
        # model: the student DetectionModel being trained
        # teacher: the trained teacher DetectionModel (e.g. YOLO("best.pt").model)
        super().__init__(model, tal_topk)

        self.model = model
        self.teacher = teacher
        self.kd_weight = kd_weight

        # Find the Detect heads of student and teacher
        self.student_detect = self._get_detect(model.model)
        self.teacher_detect = self._get_detect(teacher.model)

        # Tell Detect to cache the incoming P3/P4/P5 feature maps
        self.student_detect.kd_collect = True
        self.teacher_detect.kd_collect = True

        # Freeze the teacher
        self.teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad = False

    @staticmethod
    def _get_detect(model):
        # isinstance check against the real Detect class (the stock Detect module has
        # no "is_detect" attribute, so a getattr check would never match)
        for m in model.modules():
            if isinstance(m, Detect):
                return m
        raise RuntimeError("Detect module not found")

    def __call__(self, preds, batch):
        """
        preds: student forward output (the model calls the criterion as criterion(preds, batch))
        batch: dict with the training batch, including batch["img"]
        """

        # 1. Original YOLO detection loss
        det_loss, loss_items = super().__call__(preds, batch)

        # 2. Forward the teacher (must be done manually)
        with torch.no_grad():
            _ = self.teacher(batch["img"])

        # 3. Grab the cached P3/P4/P5 feature maps
        s_feats = getattr(self.student_detect, "_kd_feat", None)
        t_feats = getattr(self.teacher_detect, "_kd_feat", None)

        if s_feats is None or t_feats is None:
            return det_loss, loss_items

        # 4. Feature-level KD loss
        # NOTE: MSE needs matching shapes; yolo11n and yolo11m use different channel
        # widths, so an adaptation layer (e.g. a 1x1 conv on the student features)
        # is usually required before comparing them.
        kd_loss = 0.0
        for fs, ft in zip(s_feats, t_feats):
            kd_loss += F.mse_loss(fs, ft.detach())

        total_loss = det_loss + self.kd_weight * kd_loss

        return total_loss, loss_items

And the modified Detect.forward:

def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor] | tuple:
    """Concatenate and return predicted bounding boxes and class probabilities."""
    if getattr(self, "kd_collect", False):
        # Cache a shallow copy: the loop below rebinds x[i] in place, so storing
        # the list object itself would end up holding the cv2/cv3 head outputs
        # instead of the raw P3/P4/P5 feature maps.
        self._kd_feat = list(x)  # x = [P3, P4, P5]
    if self.end2end:
        return self.forward_end2end(x)

    for i in range(self.nl):
        x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
    if self.training:  # Training path
        return x
    y = self._inference(x)
    return y if self.export else (y, x)
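To make the trainer actually use this criterion instead of the default v8DetectionLoss, I'm planning to attach it inside the trainer, roughly like this (just a sketch: it assumes the BaseTrainer._setup_train(world_size) hook and single-GPU training, since with DDP self.model gets wrapped):

from ultralytics.models.yolo.detect import DetectionTrainer


class KDDetectionTrainer(DetectionTrainer):
    def __init__(self, cfg=None, overrides=None, _callbacks=None, teacher=None):
        super().__init__(cfg=cfg, overrides=overrides, _callbacks=_callbacks)
        self.teacher = teacher  # YOLO wrapper around the trained teacher weights

    def _setup_train(self, world_size):
        super()._setup_train(world_size)  # builds self.model, dataloaders, optimizer
        # Move the frozen teacher onto the training device
        teacher_net = self.teacher.model.to(self.device).eval()
        # BaseModel.loss() only builds the default v8DetectionLoss when
        # model.criterion is None, so attaching the KD criterion here makes
        # every training step go through KDDetectionLoss instead.
        self.model.criterion = KDDetectionLoss(self.model, teacher_net, kd_weight=1.0)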