Ultralytics YOLOv11 QAT

Hello everyone,

I’m trying to perform Quantization-Aware Training (QAT) on YOLOv11 using Ultralytics. However, I’m facing an issue where my modifications to the model are being reset before training starts.

Here’s the code I wrote for it:

from ultralytics import YOLO
import torch
from torch.nn import Module

DEVICE = 0 if torch.cuda.is_available() else 'cpu'

class QAT(Module):
    """Wrap the YOLO DetectionModel with quant/de-quant stubs for QAT."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.quant = torch.quantization.QuantStub()        # fake-quantizes the input
        self.de_quant = torch.quantization.DeQuantStub()   # de-quantizes the outputs
        self.yaml = model.yaml

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)

        # the detection head returns a list of tensors
        for i in range(len(x)):
            x[i] = self.de_quant(x[i])
        return x

model = YOLO('yolo11n.pt')

qa_model = QAT(model.model)
# use the QAT qconfig (with fake-quant modules), not the PTQ one
qa_model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(qa_model.train(), inplace=True)
model.model = qa_model

model.train(
    data='dataset/data.yaml',
    epochs=100
)

It seems that Ultralytics resets the model to its default state before training starts, discarding my modifications.

To debug this, I added a callback:

def print_yolo_before_start(trainer):
    print(trainer)

model.add_callback('on_train_start', print_yolo_before_start)

But when this runs, it prints the original YOLO model instead of the QAT-wrapped version (the two differ in their Conv layers).
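
For what it’s worth, a more explicit check along these lines (an untested sketch I put together) should show whether any fake-quant modules survive into the model the trainer actually uses:

import torch

def check_qat(trainer):
    # Count fake-quant modules in the model the trainer will actually train;
    # zero here would confirm the QAT wrapper was discarded.
    n_fq = sum(isinstance(m, torch.ao.quantization.FakeQuantize)
               for m in trainer.model.modules())
    print(type(trainer.model).__name__, '- fake-quant modules:', n_fq)

model.add_callback('on_train_start', check_qat)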

How can I ensure my QAT modifications persist during training?

You can check this

Thanks for your response.

Is it possible to call QuantStub() on the input image right before the model’s forward pass and DeQuantStub() on the outputs right after it, using callbacks, to eliminate the need to customize the Ultralytics source code?

Right now the new QAT class doesn’t expose the attributes the trainer expects, so it’s giving me the following error:

Traceback (most recent call last):
  File "/home/srezas/Programming/projects/ariapa/qat.py", line 159, in <module>
    train_model(
  File "/home/srezas/Programming/projects/ariapa/qat.py", line 96, in train_model
    result = model.train(
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/model.py", line 792, in train
    self.trainer.train()
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 211, in train
    self._do_train(world_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 327, in _do_train
    self._setup_train(world_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 291, in _setup_train
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/models/yolo/detect/train.py", line 82, in get_dataloader
    dataset = self.build_dataset(dataset_path, mode, batch_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/models/yolo/detect/train.py", line 64, in build_dataset
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1928, in __getattr__
    raise AttributeError(
AttributeError: 'QAT' object has no attribute 'stride'

What if you updated your QAT class to include a stride attribute? Perhaps try initializing QAT using QAT(model.model.model), since IIRC the innermost model.model.model is the collection of PyTorch layers. I haven’t tried anything with the torch.quantization module, but I suspect it takes a fairly in-depth understanding of how it operates to apply it correctly to Ultralytics YOLO.
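
As a rough, untested sketch of that idea (the exact attribute list the trainer needs may differ between versions), the wrapper could copy over whatever attributes exist on the wrapped DetectionModel, stride included:

import torch
from torch.nn import Module

class QAT(Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.quant = torch.ao.quantization.QuantStub()
        self.de_quant = torch.ao.quantization.DeQuantStub()
        # Mirror the attributes the Ultralytics trainer reads directly from the model;
        # copy defensively, since the exact set may differ between versions.
        for attr in ('yaml', 'stride', 'names', 'args', 'nc', 'task'):
            if hasattr(model, attr):
                setattr(self, attr, getattr(model, attr))

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        # the detection head returns a list/tuple of tensors during training
        if isinstance(x, (list, tuple)):
            return [self.de_quant(xi) for xi in x]
        return self.de_quant(x)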

I’ve made several changes to support QAT but am encountering an error during the final INT8 conversion step.

  1. I’ve added activation: nn.ReLU(0.1) to yolo11.yaml in order to change the default activation to ReLU.
  2. Pre-Training Preparation (in ultralytics/engine/trainer.py):

Added this before the training loop in BaseTrainer._do_train():

self.model.eval()
for m in self.model.modules():
    # Ultralytics Conv blocks expose .conv, .bn and .act submodules
    if hasattr(m, 'bn') and hasattr(m, 'conv') and \
            hasattr(m, 'act') and isinstance(m.act, nn.ReLU):
        print('fuse conv+bn+act')
        torch.ao.quantization.fuse_modules(
            m,
            [['conv', 'bn', 'act']],
            inplace=True
        )
    elif hasattr(m, 'bn') and hasattr(m, 'conv'):
        print('fuse conv+bn')
        torch.ao.quantization.fuse_modules(
            m,
            [['conv', 'bn']],
            inplace=True
        )
self.model.qconfig = torch.ao.quantization.get_default_qat_qconfig()
torch.ao.quantization.prepare_qat(self.model.train(), inplace=True)
  3. EMA Updates Adjustment (in ultralytics/utils/torch_utils.py):

Modified ModelEMA.update() to skip BN layers (since they’re fused):

# inside ModelEMA.update(); d is the EMA decay and msd is the model state_dict
for k, v in self.ema.state_dict().items():
    if 'bn' not in k and v.dtype.is_floating_point:  # skip fused BatchNorm entries
        v *= d
        v += (1 - d) * msd[k].detach()
  4. Post-Training Conversion (failing):

At the end of _do_train(), before self._clear_memory(), I added:

self.model.eval()
model_int8 = torch.ao.quantization.convert(self.model)  # Error occurs here
torch.save(model_int8, 'yolov11_int8.pt')

And I’m getting the following error from the torch.ao.quantization.convert() call:

Traceback (most recent call last):
  File "/home/srezas/projects/QAT/qat.py", line 87, in <module>
    train_model(
  File "/home/srezas/projects/QAT/qat.py", line 23, in train_model
    result = model.train(
  File "/home/srezas/projects/QAT/ultralytics/engine/model.py", line 791, in train
    self.trainer.train()
  File "/home/srezas/projects/QAT/ultralytics/engine/trainer.py", line 212, in train
    self._do_train(world_size)
  File "/home/srezas/projects/QAT/ultralytics/engine/trainer.py", line 496, in _do_train
    model_int8 = torch.ao.quantization.convert(self.model)
RuntimeError: Unsupported qscheme: per_channel_affine

Please let me know if you need any other information.

Also, here’s the model architecture after the prepare_qat() call.

Before going further on this, it would be helpful to understand why you want to do QAT rather than PTQ. Have you tried PTQ at all? If you haven’t, I would strongly recommend testing PTQ first to see whether you can achieve reasonable results before spending time figuring out how to make QAT work.

Standard PTQ using the export function as below drops mAP50-95 from 0.77 to 0.

model.export(format='engine', dynamic=True, simplify=True, int8=True, data='dataset/dataset.yaml')

That’s an exceptionally severe drop in performance. How many images are in the val split of your dataset? Additionally, have you tried without simplify=True? I can’t recall whether that introduces issues with PTQ for TensorRT, but I suspect it could.

There are about 1,500 images in my validation set. I’ve also tried INT8 quantization with simplify=False and the results are the same.


Thanks for the detailed update and for sharing your code modifications for attempting QAT. Integrating QAT directly into the training loop requires careful adjustments to the underlying code, as you’ve experienced.

The RuntimeError: Unsupported qscheme: per_channel_affine during the torch.ao.quantization.convert() step typically indicates that the chosen quantization configuration (qconfig) or the conversion process encountered a layer or parameter quantization scheme it doesn’t support for the target backend. This can sometimes happen with complex architectures or specific fused modules.
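
One way to narrow that down (a rough sketch; weight_fake_quant is the attribute that eager-mode prepare_qat attaches to prepared modules) is to list which modules carry a per-channel weight fake-quant before calling convert():

import torch

# List prepared modules whose weight fake-quant uses a per-channel scheme;
# these are the likely sources of the per_channel_affine failure in convert().
for name, m in self.model.named_modules():
    wfq = getattr(m, 'weight_fake_quant', None)
    if wfq is not None and wfq.qscheme in (torch.per_channel_affine,
                                           torch.per_channel_symmetric):
        print(name, type(m).__name__, wfq.qscheme)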

You might need to investigate which specific module within self.model is causing this incompatibility during the conversion phase or experiment with different qconfig settings if the default one isn’t suitable for all parts of the fused model.
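
For example (an unverified sketch mirroring PyTorch’s per-tensor defaults, not something tested on YOLO11), a qconfig that fake-quantizes weights per-tensor instead of per-channel avoids the per_channel_affine scheme entirely:

import torch
from torch.ao.quantization import QConfig
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

# Per-tensor QAT qconfig: both activations and weights use per-tensor fake-quant,
# so convert() never has to handle a per-channel scheme.
per_tensor_qat_qconfig = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=0, quant_max=255,
        dtype=torch.quint8, qscheme=torch.per_tensor_affine,
    ),
    weight=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=-128, quant_max=127,
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
    ),
)

self.model.qconfig = per_tensor_qat_qconfig  # instead of get_default_qat_qconfig()

If I recall correctly, torch.ao.quantization.get_default_qat_qconfig('qnnpack') also uses per-tensor weight fake-quant, so it may be a simpler thing to try first if the qnnpack backend suits your deployment target.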

Achieving good results with QAT often involves deep integration and tuning specific to the model architecture and the quantization toolkit, which goes beyond the standard features currently offered. The significant mAP drop you observed with standard INT8 PTQ via model.export() highlights the difficulty, although ensuring a sufficiently representative calibration dataset (as you did with data='dataset/dataset.yaml') is crucial for PTQ performance.