YOLO Custom Loss Function (Big Practical Integration Question)

Hi,

I am working on a custom detection loss for YOLOv8 to handle a very specific positional-bias issue in my dataset, and I would like feedback from people who know the YOLOv8 loss / assignment internals.

The key thing I want to validate is whether the way I:

  • map my external GT “pair” metadata into the loss coordinate space, and

  • pick the positive anchors and apply a teacher–student consistency penalty

is logically correct and compatible with YOLOv8's loss/assignment internals.

I will show the exact code block at the end and highlight my questions.


1. Data situation and pair metadata (quick explanation of the problem)

In my dataset, some images can have the same physical defect appear twice in the image, vertically shifted by a fixed amount in y. Visually, those two boxes are the “same defect” with the same local background, just in different vertical positions.

In addition, I maintain a separate JSON file (loaded from PAIRS_GT_PATH) with pair metadata keyed by image basename:

import json
import os

pairs_gt = {}

if os.path.isfile(PAIRS_GT_PATH):
    with open(PAIRS_GT_PATH, "r", encoding="utf-8") as f:
        _pg = json.load(f)
    # normalize to list of [ [c,x,y,w,h], [c,x,y,w,h] ]
    for k, v in _pg.items():
        norm = []
        for pair in v:
            a, b = pair
            a = list(a); b = list(b)
            norm.append([a, b])
        pairs_gt[k] = norm

These pairs are always same-class pairs (duplicate view of the same defect).
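For illustration, an entry in that file looks roughly like this (made-up values; each pair is two normalized [class, x, y, w, h] boxes for the same defect, vertically shifted):

{
  "image_0001": [
    [
      [2, 0.41, 0.18, 0.05, 0.07],
      [2, 0.41, 0.88, 0.05, 0.07]
    ]
  ]
}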


2. Base detection loss setup

I define a custom v8DetectionLoss that is structurally very close to Ultralytics’ YOLOv8 detection loss and replace the original one.
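The wiring looks roughly like this (a minimal sketch; the exact attribute to override depends on the Ultralytics version, so treat the hook shown at the bottom as an assumption):

from ultralytics.utils.loss import v8DetectionLoss

class PairAwareDetectionLoss(v8DetectionLoss):
    """Structurally identical to the stock loss; the pair term is added to loss_cls."""

    def __call__(self, preds, batch):
        # copy of the stock v8DetectionLoss.__call__ body, with the
        # pair-consistency block from section 3 inserted before the final
        # cls/box/dfl combination
        ...

# attach it so training uses it (version-dependent; shown as an assumption):
# model.model.criterion = PairAwareDetectionLoss(model.model)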


3. Pair-based teacher–student consistency logic

Now the main part I need reviewed.

Goal: for each pair of GTs representing the same physical defect at two vertical positions, enforce that the model’s predictions (on their respective positive anchors) are consistent:

  • If one anchor is very confident on some class, the other anchor should not be much worse.

  • If the strong anchor is confident on class k and the weak anchor is peaking on a different class j, I encourage the weak anchor to move toward k.

To do this, I:

  1. Use batch["im_file"] to get basenames and look up pairs_gt[basename].

  2. Use batch["ori_shape"] and batch["ratio_pad"] to map original normalized xywh pairs into the letterboxed xyxy space used by gt_bboxes.

  3. For each pair, I find the corresponding GT indices (gi1, gi2) by IoU with gt_bboxes[b_idx].

  4. Using fg_mask and target_scores, I find the positive anchors supervising those GTs.

  5. On those two anchors, I compute top-3 class confidences and define a teacher–student relation.

  6. I compute a piecewise-smooth penalty on the confidence gap and optionally a mismatch penalty when teacher and student top-1 classes differ.

  7. I average all pair penalties and add them into the class loss.

Here is the exact block:

pair_loss_terms = []
im_files = batch.get('im_file', [])

# compact block: standard letterbox (using ori_shape + ratio_pad)
if isinstance(im_files, (list, tuple)) and im_files and pairs_gt:
    device = pred_scores.device
    dtype  = pred_scores.dtype

    pair_cap_t    = torch.as_tensor(pair_cap, device=device, dtype=dtype)
    lambda_pair_t = torch.as_tensor(lambda_pair, device=device, dtype=dtype)

    # network input H, W
    H = float(imgsz[0].item()); W = float(imgsz[1].item())

    ori_shape = batch.get('ori_shape', None)  # [(h0,w0), ...] or None

    for b_idx, pth in enumerate(im_files):
        base = os.path.splitext(os.path.basename(pth))[0]

        img_pairs = pairs_gt.get(base, [])
        if not img_pairs:
            continue
        if base not in bnames:  # debug filter; can be removed
            continue

        # GTs (LETTERBOXED) and classes
        gtb = gt_bboxes[b_idx].detach().float()      # (G, 4)
        gtc = gt_labels[b_idx, :, 0].detach().long() # (G,)

        # positive anchors
        pos = torch.nonzero(fg_mask[b_idx], as_tuple=False).view(-1)
        if pos.numel() == 0:
            continue

        # original size
        if ori_shape is not None and len(ori_shape) > b_idx:
            h0, w0 = map(float, ori_shape[b_idx][:2])
        else:
            h0, w0 = H, W

        rp = batch.get('ratio_pad', None)
        if rp is not None and len(rp) > b_idx:
            r, (padw, padh) = rp[b_idx]
        else:
            # fallback: canonical letterbox
            r = min(H / h0, W / w0)
            padw = (W - w0 * r) / 2.0
            padh = (H - h0 * r) / 2.0

        pos_cands = pos  # candidate anchors

        for pi, pair in enumerate(img_pairs):
            (c1, x1, y1, w1, h1), (c2, x2, y2, w2, h2) = pair
            c1 = int(c1); c2 = int(c2)

            # ORIGINAL (norm) -> ORIG xyxy -> LETTERBOXED xyxy
            p1 = apply_letterbox(
                xywhn_to_xyxy_orig(x1, y1, w1, h1, h0, w0),
                r, padw, padh
            )
            p2 = apply_letterbox(
                xywhn_to_xyxy_orig(x2, y2, w2, h2, h0, w0),
                r, padw, padh
            )

            # find GT indices by class + max IoU
            gi1 = gi2 = -1
            best1 = best2 = -1.0
            for gi in range(gtb.shape[0]):
                cls = int(gtc[gi])
                if cls == c1:
                    iou = iou_xyxy(gtb[gi].tolist(), p1)
                    if iou > best1:
                        best1, gi1 = iou, gi
                if cls == c2:
                    iou = iou_xyxy(gtb[gi].tolist(), p2)
                    if iou > best2:
                        best2, gi2 = iou, gi

            # require good match to existing GTs
            if gi1 < 0 or gi2 < 0 or best1 < 0.85 or best2 < 0.85:
                continue

            # filter positive anchors supervising class c1
            pos_c = pos_cands[target_scores[b_idx, pos_cands, c1] > 0]
            if pos_c.numel() == 0:
                continue
            tb_pos = target_bboxes[b_idx, pos_c]  # (K, 4)

            # best positive anchor for each GT: max IoU in tb_pos
            def _best_anchor_for(gt_box):
                ious = [iou_xyxy(tb_pos[k].tolist(), gt_box.tolist())
                        for k in range(tb_pos.shape[0])]
                k = int(torch.tensor(ious).argmax().item())
                return int(pos_c[k].item())

            a1 = _best_anchor_for(gtb[gi1])
            a2 = _best_anchor_for(gtb[gi2])

            # logits and probs
            logits_a1 = pred_scores[b_idx, a1, :]   # [C]
            logits_a2 = pred_scores[b_idx, a2, :]   # [C]
            probs_a1  = torch.sigmoid(logits_a1)
            probs_a2  = torch.sigmoid(logits_a2)

            topk_vals,  topk_ids  = torch.topk(probs_a1,  k=3)
            topk_vals2, topk_ids2 = torch.topk(probs_a2, k=3)

            top1_conf1 = topk_vals[0]
            top1_conf2 = topk_vals2[0]
            top1_c1    = topk_ids[0].item()
            top1_c2    = topk_ids2[0].item()

            # teacher–student choice
            if float(top1_conf1) >= float(top1_conf2):
                teacher_conf = top1_conf1.detach()  # no grad
                student_conf = top1_conf2           # gets grad
                teacher_cls  = top1_c1
                student_cls  = top1_c2
                student_probs = probs_a2
            else:
                teacher_conf = top1_conf2.detach()
                student_conf = top1_conf1
                teacher_cls  = top1_c2
                student_cls  = top1_c1
                student_probs = probs_a1

            same_class    = (teacher_cls == student_cls)
            trust_teacher = (float(teacher_conf) >= tau_high)

            raw_gap = teacher_conf - student_conf - margin
            raw_gap = torch.clamp(raw_gap, min=0.0)

            pair_cap_t = torch.tensor(pair_cap, device=pred_scores.device, dtype=raw_gap.dtype)
            delta = pair_cap_t * 0.7
            abs_diff = raw_gap.abs()
            L_gap = torch.where(
                abs_diff < delta,
                0.5 * abs_diff**2,
                delta * (abs_diff - 0.5 * delta),
            )
            L_gap = L_gap.view(1)

            if trust_teacher and same_class and raw_gap > 0:
                # strong, directional consistency term on the confidence gap
                L_pair_per = lambda_pair_t * L_gap

            elif trust_teacher and not same_class:
                # teacher strong but student top-1 is another class:
                # push student probability towards teacher class
                p_student_teacher = student_probs[teacher_cls]
                diff_cls = torch.clamp(teacher_conf - p_student_teacher, min=0.0)
                L_cls = diff_cls * diff_cls
                L_pair_per = lambda_pair_t * mismatch_k * L_cls.view(1)

            else:
                # weak teacher or no gap: do nothing
                L_pair_per = torch.zeros(1, device=pred_scores.device)

            pair_loss_terms.append(L_pair_per)

# final cls loss combination
target_scores_sum = max(target_scores.sum(), 1)
loss_cls = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

if pair_loss_terms:
    pair_loss_mean = torch.stack(pair_loss_terms).mean()
    loss_cls = loss_cls + pair_loss_mean * pair_gain
    pair_loss_terms.clear()

loss_cls is then combined with box and DFL losses as usual.
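For completeness, here are simplified versions of the three helpers the block above assumes (my real implementations may differ in small details, but this is what they do conceptually):

def xywhn_to_xyxy_orig(x, y, w, h, h0, w0):
    # normalized center xywh -> absolute xyxy in the ORIGINAL image
    x1 = (x - w / 2.0) * w0
    y1 = (y - h / 2.0) * h0
    x2 = (x + w / 2.0) * w0
    y2 = (y + h / 2.0) * h0
    return [x1, y1, x2, y2]

def apply_letterbox(xyxy, r, padw, padh):
    # scale by the letterbox ratio and shift by the padding
    x1, y1, x2, y2 = xyxy
    return [x1 * r + padw, y1 * r + padh, x2 * r + padw, y2 * r + padh]

def iou_xyxy(a, b):
    # plain IoU of two xyxy boxes given as lists
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    iw, ih = max(0.0, ix2 - ix1), max(0.0, iy2 - iy1)
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-9)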


4. What I want to confirm

I would appreciate validation or corrections on the following points:

  1. Coordinate space and mapping

    • I start from original normalized pairs x,y,w,h (in pairs_gt), then:

      • Convert to original xyxy with xywhn_to_xyxy_orig.

      • Map to letterboxed xyxy with apply_letterbox using the same r, padw, padh derived from ori_shape and ratio_pad.

    • I then match to gt_bboxes[b_idx], which is the output of preprocess after scaling to the model input space.
      Is this the correct way to ensure that the pair boxes and gt_bboxes live in the same coordinate system as used by the assigner?

  2. Use of fg_mask and target_scores to select positive anchors

    • I take pos = nonzero(fg_mask[b_idx]) to get anchors with at least one positive class.

    • Then I further filter to anchors supervising c1 via target_scores[b_idx, pos_cands, c1] > 0.

    • I then use IoU between tb_pos (those anchors’ target boxes) and the GT boxes to pick the “best” anchor per GT.
      Is this consistent with how YOLOv8 expects you to identify the positive anchors associated with a given GT after TaskAlignedAssigner?

  3. Teacher–student selection and gradient flow

    • I select the anchor with higher top-1 confidence as the teacher and detach its confidence.

    • The student anchor’s probabilities (student_probs) stay attached to the graph and receive gradients from L_gap and/or L_cls.
      Do you see any issue with this asymmetric setup inside the detection loss (in terms of stability, interaction with BCE, or how the gradients propagate through pred_scores)?

  4. Piecewise-smooth gap penalty and cap

    • I use a Huber-like penalty on the confidence gap (L_gap) with a cap pair_cap.

    • Then I scale by lambda_pair_t and finally by pair_gain when adding to loss_cls.
      Is adding this extra scalar penalty directly into the classification loss (loss_cls) a reasonable way to integrate such a regularizer, or would you recommend combining it differently (e.g., separate term, different normalization)?

  5. Overall pattern: external metadata + assigner outputs
    Conceptually, I am:

    • Keeping YOLO label files standard.

    • Using external JSON metadata (pairs_gt[basename]) only inside the loss.

    • Using assigner outputs (target_bboxes, target_scores, fg_mask) to identify the relevant positive anchors.
      Is this pattern aligned with how you would recommend injecting a custom, GT-pair-based regularization term into YOLOv8, or is there a cleaner / more robust hook that I am missing?

Any comments on conceptual mistakes, edge cases I am ignoring, or better ways to tie this into the existing loss machinery would be very helpful.

I used the print below for debugging, tracked only specific images, and the results matched the ground truth, so I believe the mapping is right, but the internals are confusing enough that I wanted to make sure.

print(
    f"\n[PAIR DEBUG] IMAGE {base} | pair #{pi}\n"
    f"  same class = {c1} and {c2}\n"
    f"  Anchor {a1}: top3 {[(int(c), float(v)) for c, v in zip(topk_ids, topk_vals)]}\n"
    f"  Anchor {a2}: top3 {[(int(c), float(v)) for c, v in zip(topk_ids2, topk_vals2)]}\n"
    f"  top1conf1={float(top1_conf1):.4f}  top1conf2={float(top1_conf2):.4f}\n"
    f"  top1c1={top1_c1}  top1c2={top1_c2}\n"
    f"  abs_diff={float(abs_diff):.4f}  same_class={same_class}  L_pair_per={L_pair_per}",
    file=sys.stderr,
    flush=True,
)

Before you do any of this extra work, have you tried training a model on the data as-is?

If your concern is about the double detection: if you know that the “real” defect is at position Y, and that the projected defect is at Y +/- d, then you should be able to add a simple filter to your detections, provided the orientation of the images/objects is consistent, and there’s no need for any custom loss or training workflow.
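For example, something along these lines (a hypothetical sketch, assuming xyxy boxes and a known vertical offset d in pixels):

def drop_projected_duplicates(boxes, scores, classes, d, tol=10.0):
    """Keep the higher-confidence box of any same-class pair whose centers are
    horizontally aligned and vertically separated by roughly d pixels."""
    keep = [True] * len(boxes)
    for i in range(len(boxes)):
        for j in range(i + 1, len(boxes)):
            if classes[i] != classes[j]:
                continue
            cxi, cyi = (boxes[i][0] + boxes[i][2]) / 2, (boxes[i][1] + boxes[i][3]) / 2
            cxj, cyj = (boxes[j][0] + boxes[j][2]) / 2, (boxes[j][1] + boxes[j][3]) / 2
            if abs(cxi - cxj) < tol and abs(abs(cyi - cyj) - d) < tol:
                # drop the lower-confidence member of the pair
                drop = i if scores[i] < scores[j] else j
                keep[drop] = False
    return keep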

I’ve had something similar to this issue when inspecting defects in glass. The imaging was able to capture defects on both the front and back side of the glass object. We had success labeling with defect-front and defect-back to separate them, but in the end it was easier to just ignore the region where defects would show in the back. Your situation sounds even easier, as the defects we saw were not the same ones, and so there was no alignment with front and back defects. Since yours are the same and always aligned, it’s much easier to know which ones to reject.
You’ll save yourself a lot of time now and in the future by taking the simpler route. I highly recommend filtering post-detection instead of trying to engineer a model-based solution.

Yes, and then I created some metrics to see the differences in pair predictions (confidence, class, one detected while the other is not). Now my objective is to make these metrics as equal as possible for every image (minimize the differences in pair predictions). I tried specific data augmentations, fine-tuning with pair-only images (some images have pairs, others don’t), and now I’m trying this loss-function strategy.

My goal is more academic research, so simply removing the second detection is not really the objective, but thanks for the suggestion.

I just wanted to make sure everything I was doing is correct, because there seems to be very little documentation on altering the default YOLOv8 loss function.

Thank you!

If I understand correctly, the issue is that the “virtual” defect is being classified differently than the “real” defect in the image, and you want the detection to be the same for both?

Showing an example image with the detections and ground truth labels will go a long way in helping to understand the problem you’re facing and attempting to resolve. I’m struggling at the moment to grasp the scope of the issue, which makes it difficult to provide support.

Do all the objects you want to detect, including the pairs, have labels? Or does just one object in each pair have a label? I don’t understand the usefulness of using the confidence of one anchor as a teacher for another when the ground truth provides a better and more foolproof signal, while the anchor’s prediction may or may not be correct.

Also, the mapping from the loaded pair coordinates would not work unless you take into account the transforms applied by augmentations. The GT bboxes available at this stage have already been transformed by the augmentations, but the exact transformation matrix is not kept, so this would only work if you disable augmentations completely.

The help I need is more about confirming that the code I sent seems fine, but to give some context: the images are the lateral view of a rotating object (it rotates slightly more than 360 degrees, so some defects appear twice if they are in certain positions: once near the top and again 0.70*h further down). Something like this (I drew it in Paint because the dataset is private; the red circle is the defect that appears twice, the rest are just random defects that only appear once):

Yes, they do.

Exactly, thank you so much for the answer. You are correct that I have to disable all position-altering augmentations for this (which unfortunately decreases the overall metrics a bit, but I don’t really have a better solution; what I am trying right now is to train the model with run-time augmentations on and then only fine-tune it with my custom loss function). Maybe comparing them against the GT boxes would be a better strategy, you are right.

But does the logic seem correct? Thank you so much.

I do understand you’re asking for help with the code, and you should read this post to understand why I’m not talking about the code.


True, I understand what you mean, thank you

You’re on the right track conceptually, and your overall “external metadata + assigner outputs + extra regularizer inside the loss” pattern is exactly how I’d plug this kind of idea into YOLOv8 (and it carries over to YOLO11 as well).

On coordinate spaces: gt_bboxes in v8DetectionLoss are in the model-input, letterboxed xyxy space after all augmentations (preprocess() scales the normalized xywh targets by imgsz and converts them with xywh2xyxy). Your mapping from normalized original pairs → original xyxy → letterboxed xyxy is correct only if you disable any spatial/geometric augmentations (mosaic, random affine, flips, etc.). With those disabled, your xywhn_to_xyxy_orig + apply_letterbox(r, padw, padh) should land in the same space as gt_bboxes, as you intended.

For selecting anchors: using fg_mask[b_idx] to get positives and then filtering by target_scores[b_idx, pos_cands, c1] > 0 is consistent with how TaskAlignedAssigner is used in v8DetectionLoss. However, instead of re-discovering the mapping via IoU over target_bboxes, I would strongly recommend exposing the fifth output from the assigner (what v8SegmentationLoss and v8PoseLoss call target_gt_idx) in your custom detection loss. That gives you, for each positive anchor, the exact GT index chosen by the assigner and removes the need for a second IoU-based matching step, which can diverge slightly from what the assigner did.
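A minimal sketch of what I mean, assuming the stock assigner call inside v8DetectionLoss (the exact unpacking may differ slightly between Ultralytics versions):

# keep the fifth assigner output instead of discarding it
target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
    pred_scores.detach().sigmoid(),
    (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
    anchor_points * stride_tensor,
    gt_labels,
    gt_bboxes,
    mask_gt,
)

# then, for image b_idx and GT index gi, its positive anchors are simply:
anchors_for_gi = torch.nonzero(fg_mask[b_idx] & (target_gt_idx[b_idx] == gi)).view(-1)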

The teacher–student part as you’ve implemented it (higher-confidence anchor as teacher, detached, lower one as student getting gradients through student_probs) is a standard asymmetric setup and is fine from a gradient-flow perspective. The only practical concern is stability and over-regularization early in training, so I’d keep lambda_pair, pair_cap, and pair_gain small at first or enable this term only after a short warmup, then watch per-class metrics and loss curves.

Adding the pair penalty directly into loss_cls is reasonable: you’re effectively adding a regularization term that still backprops through pred_scores. Since you already average over pairs, the main thing is to treat it like any other auxiliary loss and tune its overall weight so it’s on a similar scale as the BCE term (you can log loss_cls and your pair term separately to check they’re not orders of magnitude apart).

So in short: yes, the logic is compatible with the YOLOv8 loss/assignment internals, provided you either (1) freeze spatial augmentations as you already noted, or (2) move to using target_gt_idx from the assigner instead of reconstructing matches by IoU. For an academic exploration of positional bias, your approach is a reasonable and “clean” way to inject pairwise structure into the training signal.


Thank you so much for the awesome and detailed response and suggestions!

Since I can’t really identify the pairs once they have been positionally altered, I turn off the automatic augmentations; what I am trying right now is training with augmentations for 30 epochs with the normal loss and then fine-tuning for 10 epochs with my new loss.

Do you think this strategy makes sense?

Another quick question: your suggestion of “exposing the fifth output from the assigner (what v8SegmentationLoss and v8PoseLoss call target_gt_idx)” is for making the code as robust as possible (being 100% sure the correct GT box is chosen) and as clean as possible, is that correct?