Ultralytics YOLOv11 QAT

Hello everyone,

I’m trying to perform Quantization-Aware Training (QAT) on YOLOv11 using Ultralytics. However, I’m facing an issue where my modifications to the model are being reset before training starts.

Here’s the code I wrote for it:

from ultralytics import YOLO
import torch
from torch.nn import Module

DEVICE = 0 if torch.cuda.is_available() else 'cpu'

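class QAT(Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.quant = torch.quantization.QuantStub()       # after prepare_qat(), fake-quantizes the input
        self.de_quant = torch.quantization.DeQuantStub()  # identity during QAT; becomes DeQuantize after convert()
        self.yaml = model.yaml  # Ultralytics reads model.yaml when it sets up training

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)

        # De-quantize each output (in training mode, the Detect head returns one tensor per scale)
        for i in range(len(x)):
            x[i] = self.de_quant(x[i])
        return x

model = YOLO('yolo11n.pt')

qa_model = QAT(model.model)
# QAT needs fake-quantize modules; get_default_qconfig() returns observer-only settings for post-training quantization
qa_model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
qa_model.train()  # prepare_qat() expects the model to be in training mode
torch.quantization.prepare_qat(qa_model, inplace=True)
model.model = qa_model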

model.train(
    data='dataset/data.yaml',
    epochs=100
)
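Separately from the reset problem, the choice of qconfig matters for QAT: in PyTorch's eager-mode API, `get_default_qconfig(...)` returns observer-only settings meant for post-training quantization, while `get_default_qat_qconfig(...)` returns fake-quantize modules that training actually learns against. The toy comparison below is a minimal sketch of that difference; a single `nn.Conv2d` stands in for the YOLO model, and it uses the `torch.ao.quantization` namespace, which `torch.quantization` aliases.

```python
import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig, get_default_qat_qconfig, prepare_qat
from torch.ao.quantization.fake_quantize import FakeQuantizeBase


def weight_quantizer(qconfig):
    """Prepare a toy one-conv model for QAT and return the module applied to its weights."""
    toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3))  # stand-in for the YOLO model
    toy.train()  # prepare_qat() requires training mode
    toy.qconfig = qconfig
    prepare_qat(toy, inplace=True)  # swaps nn.Conv2d for its QAT counterpart
    return toy[0].weight_fake_quant


ptq_wq = weight_quantizer(get_default_qconfig("qnnpack"))      # post-training qconfig
qat_wq = weight_quantizer(get_default_qat_qconfig("qnnpack"))  # QAT qconfig

# Only the QAT qconfig yields a module that fake-quantizes the weights during training
print(type(ptq_wq).__name__, isinstance(ptq_wq, FakeQuantizeBase))
print(type(qat_wq).__name__, isinstance(qat_wq, FakeQuantizeBase))
```

With the post-training qconfig, the weights are only observed, so the network never trains against quantization error.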

It seems that Ultralytics resets the model to its default state before training starts, discarding my modifications.
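As far as I can tell from the Ultralytics source, `Model.train()` asks the trainer for a fresh model, roughly `trainer.get_model(cfg=self.model.yaml, weights=...)`, which builds a new `DetectionModel` from the YAML config and copies matching weights into it. Copying a `state_dict` transfers tensors, not module types, so the layer swaps made by `prepare_qat` do not survive. The torch-only sketch below reproduces that effect on a toy model; `build_from_cfg` and `cfg` are hypothetical stand-ins for building from `model.yaml`.

```python
import torch
from torch import nn
import torch.ao.nn.qat as nnqat
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat


def build_from_cfg(cfg: dict) -> nn.Module:
    """Hypothetical stand-in for building a model from a YAML config."""
    return nn.Sequential(nn.Conv2d(cfg["ch"], 8, kernel_size=3), nn.ReLU())


cfg = {"ch": 3}  # hypothetical config

# 1) Prepare a model for QAT, as in the question
prepared = build_from_cfg(cfg)
prepared.train()
prepared.qconfig = get_default_qat_qconfig("qnnpack")
prepare_qat(prepared, inplace=True)

# 2) What a trainer that rebuilds from config effectively does:
#    construct a fresh model, then copy over the tensors whose names match
rebuilt = build_from_cfg(cfg)
result = rebuilt.load_state_dict(prepared.state_dict(), strict=False)

print(isinstance(prepared[0], nnqat.Conv2d))  # the prepared layer is a QAT conv
print(type(rebuilt[0]) is nn.Conv2d)          # the rebuilt layer is a plain nn.Conv2d again
print(len(result.unexpected_keys), "fake-quant/observer entries were ignored")
```

The weights carry over, but the fake-quant modules are dropped as unexpected keys, which matches the symptom of the model looking "reset".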

To debug this, I added a callback:

def print_yolo_before_start(trainer):
    # trainer.model is the model the trainer will actually train
    print(trainer.model)

model.add_callback('on_train_start', print_yolo_before_start)

But when this runs, it prints the original YOLO model instead of the QAT-wrapped version: the Conv layers are the original ones rather than the QAT-prepared replacements.
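Comparing printed Conv layers by eye is error-prone. A small helper that counts QAT-swapped convolutions and fake-quantize modules gives a direct answer and only relies on torch. `qat_summary` and `check_qat` are hypothetical names; the idea is to register `check_qat` with `model.add_callback('on_train_start', check_qat)`, the same way as the debug callback.

```python
import torch
from torch import nn
import torch.ao.nn.qat as nnqat
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
from torch.ao.quantization.fake_quantize import FakeQuantizeBase


def qat_summary(model: nn.Module) -> dict:
    """Count layers prepared for QAT and fake-quantize modules in `model`."""
    modules = list(model.modules())
    return {
        "qat_convs": sum(isinstance(m, nnqat.Conv2d) for m in modules),
        "fake_quants": sum(isinstance(m, FakeQuantizeBase) for m in modules),
    }


def check_qat(trainer):
    """Hypothetical callback: report whether the trainer's model is still prepared for QAT."""
    summary = qat_summary(trainer.model)
    print("QAT active" if summary["fake_quants"] else "QAT missing", summary)


# Quick self-check on toy models standing in for trainer.model
plain = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3))
prepared = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3))
prepared.train()
prepared.qconfig = get_default_qat_qconfig("qnnpack")
prepare_qat(prepared, inplace=True)
print(qat_summary(plain), qat_summary(prepared))
```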

How can I ensure my QAT modifications persist during training?

You can check this

Thanks for your response.

Is it possible to call QuantStub() on the input image right before the model's forward pass, and DeQuantStub() on its output right after it, using callbacks? That would eliminate the need to customize the Ultralytics source code.

I'm asking because the new QAT class doesn't have the attributes Ultralytics expects, and it gives me the following error:

Traceback (most recent call last):
  File "/home/srezas/Programming/projects/ariapa/qat.py", line 159, in <module>
    train_model(
  File "/home/srezas/Programming/projects/ariapa/qat.py", line 96, in train_model
    result = model.train(
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/model.py", line 792, in train
    self.trainer.train()
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 211, in train
    self._do_train(world_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 327, in _do_train
    self._setup_train(world_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 291, in _setup_train
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/models/yolo/detect/train.py", line 82, in get_dataloader
    dataset = self.build_dataset(dataset_path, mode, batch_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/models/yolo/detect/train.py", line 64, in build_dataset
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1928, in __getattr__
    raise AttributeError(
AttributeError: 'QAT' object has no attribute 'stride'
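On the callback question: as far as I can tell, Ultralytics callbacks such as `on_train_batch_start` fire around whole training steps rather than inside the model's forward pass, and during training the trainer calls the model with the batch dict (the model's forward then computes the loss), so a hook on the top-level forward would not see the image tensor. Plain PyTorch hooks on individual layers can apply the stubs without wrapping the model, which keeps attributes like `stride` where Ultralytics looks for them. The sketch below uses a hypothetical `ToyDetector`; for YOLO, the first and last layers would presumably be `model.model.model[0]` and `model.model.model[-1]`, which is an assumption.

```python
import torch
from torch import nn
import torch.ao.nn.qat as nnqat
from torch.ao.quantization import DeQuantStub, QuantStub, get_default_qat_qconfig, prepare_qat


class ToyDetector(nn.Module):
    """Hypothetical stand-in for an Ultralytics DetectionModel."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.Conv2d(8, 8, 3, stride=2, padding=1),
        ])
        self.stride = torch.tensor([4.0])  # attribute that outside code reads directly

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return [x]


def attach_stub_hooks(model: nn.Module, first: nn.Module, last: nn.Module):
    """Run model.quant on `first`'s input and model.de_quant on `last`'s output via hooks."""

    def quantize_input(module, args):
        # Replace the first positional argument (the image tensor); keep the rest
        return (model.quant(args[0]), *args[1:])

    def dequantize_output(module, args, output):
        if isinstance(output, (list, tuple)):
            return type(output)(model.de_quant(t) for t in output)
        return model.de_quant(output)

    return first.register_forward_pre_hook(quantize_input), last.register_forward_hook(dequantize_output)


net = ToyDetector()
# Register the stubs on the model itself so prepare_qat() attaches fake-quant to them
net.quant, net.de_quant = QuantStub(), DeQuantStub()
net.train()
net.qconfig = get_default_qat_qconfig("qnnpack")
prepare_qat(net, inplace=True)  # swaps the Conv2d layers for QAT versions in place

# Attach the hooks after prepare_qat(), because it replaces the layer objects
handles = attach_stub_hooks(net, first=net.layers[0], last=net.layers[-1])
out = net(torch.randn(1, 3, 32, 32))
print(type(net.layers[0]).__name__, out[0].shape, net.stride)
```

Two caveats with this design: hooks belong to module objects, so they would be lost if the trainer rebuilds the model, and the closures refer to `net`'s own stubs, so a deepcopy of the model (Ultralytics keeps an EMA copy, for instance) would still call the original stubs.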

What if you updated your QAT class to include a stride attribute? Perhaps try initializing QAT with QAT(model.model.model) since, iirc, that third .model is the collection of PyTorch layers. I haven’t tried anything with the torch.quantization module, but I suspect it takes a fairly in-depth understanding of how it operates to apply it correctly to Ultralytics YOLO.
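Following the stride suggestion: rather than copying attributes one by one (`stride`, `yaml`, `names`, ...), a wrapper can fall back to the wrapped model for any attribute it does not define itself by overriding `nn.Module.__getattr__`. This is a hedged variant of the question's `QAT` class; `ToyDetectionModel` and its attribute values are hypothetical stand-ins.

```python
import torch
from torch import nn
from torch.ao.quantization import DeQuantStub, QuantStub


class QAT(nn.Module):
    """Quant/dequant wrapper that forwards unknown attributes to the wrapped model."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.quant = QuantStub()
        self.de_quant = DeQuantStub()

    def __getattr__(self, name: str):
        try:
            # Parameters, buffers and submodules registered on the wrapper itself
            return super().__getattr__(name)
        except AttributeError:
            # Otherwise look the name up on the wrapped model (e.g. stride, yaml, names)
            return getattr(super().__getattr__("model"), name)

    def forward(self, x):
        x = self.model(self.quant(x))
        if isinstance(x, (list, tuple)):
            return type(x)(self.de_quant(t) for t in x)
        return self.de_quant(x)


class ToyDetectionModel(nn.Module):
    """Hypothetical stand-in exposing the kind of attributes the trainer reads."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.stride = torch.tensor([8.0, 16.0, 32.0])  # hypothetical values
        self.yaml = {"nc": 80}  # hypothetical config

    def forward(self, x):
        return [self.conv(x)]


wrapped = QAT(ToyDetectionModel())
out = wrapped(torch.randn(1, 3, 16, 16))
print(int(wrapped.stride.max()), wrapped.yaml, out[0].shape)
```

Delegating through `__getattr__` keeps attributes in sync with the wrapped model instead of copying them once. It only addresses the AttributeError, though: as far as I can tell, during training the trainer passes the batch dict to the model, and the fake-quant attached to a QuantStub expects a tensor, so `forward` would still need to handle that case.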