I’m trying to perform Quantization-Aware Training (QAT) on YOLOv11 using Ultralytics. However, I’m facing an issue where my modifications to the model are being reset before training starts.
Here’s the code I wrote for it:
from ultralytics import YOLO
import torch
from torch.nn import Module

DEVICE = 0 if torch.cuda.is_available() else 'cpu'


class QAT(Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.quant = torch.quantization.QuantStub()
        self.de_quant = torch.quantization.DeQuantStub()
        self.yaml = model.yaml

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        for i in range(len(x)):
            x[i] = self.de_quant(x[i])
        return x


model = YOLO('yolo11n.pt')
qa_model = QAT(model.model)
qa_model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.quantization.prepare_qat(qa_model, inplace=True)
model.model = qa_model
model.train(
    data='dataset/data.yaml',
    epochs=100
)
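A side note on the preparation step in the code above, as a minimal sketch rather than a verified fix: `get_default_qconfig` returns a post-training configuration built from observers, whereas `prepare_qat` is normally paired with `get_default_qat_qconfig`, which attaches fake-quantize modules. `prepare_qat` also expects the model to be in training mode. `TinyNet` below is a hypothetical stand-in for the wrapped detector, not an Ultralytics class, and nothing here has been checked against the Ultralytics trainer.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the wrapped detector (not an Ultralytics class)."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.de_quant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.de_quant(self.conv(self.quant(x)))


net = TinyNet()
net.train()  # prepare_qat expects a model in training mode

# QAT-specific config: fake-quantize modules instead of plain observers.
net.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(net, inplace=True)

# After preparation the conv layer should carry a weight fake-quantizer.
has_fake_quant = hasattr(net.conv, 'weight_fake_quant')

# A forward pass still works and keeps the output shape unchanged.
out = net(torch.rand(1, 3, 16, 16))
print(has_fake_quant, tuple(out.shape))
```

Checking for `weight_fake_quant` on a known layer is a quick way to confirm that preparation actually modified the model before handing it to a trainer.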
It seems that Ultralytics resets the model to its default state before training starts, discarding my modifications.
Is it possible to use callbacks to apply QuantStub() to the input image right before the model’s forward pass and DeQuantStub() to the outputs right after it, so that I don’t need to customize the Ultralytics source code?
As it stands, the new QAT class lacks the attributes the trainer expects, and I get the following error:
Traceback (most recent call last):
  File "/home/srezas/Programming/projects/ariapa/qat.py", line 159, in <module>
    train_model(
  File "/home/srezas/Programming/projects/ariapa/qat.py", line 96, in train_model
    result = model.train(
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/model.py", line 792, in train
    self.trainer.train()
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 211, in train
    self._do_train(world_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 327, in _do_train
    self._setup_train(world_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/engine/trainer.py", line 291, in _setup_train
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/models/yolo/detect/train.py", line 82, in get_dataloader
    dataset = self.build_dataset(dataset_path, mode, batch_size)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/ultralytics/models/yolo/detect/train.py", line 64, in build_dataset
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
  File "/home/srezas/Programming/projects/ariapa/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1928, in __getattr__
    raise AttributeError(
AttributeError: 'QAT' object has no attribute 'stride'
What if you updated your QAT class to include a stride attribute? You could also try initializing it with QAT(model.model.model) since, if I recall correctly, that innermost model attribute is the collection of PyTorch layers. I haven’t tried anything with the torch.quantization module, but I suspect it takes a fairly in-depth understanding of how it operates to apply it correctly to Ultralytics YOLO.
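One way to sketch the first suggestion, without hard-coding every attribute the trainer might read, is to let the wrapper fall back to the wrapped model for anything it lacks. This is an untested idea as far as Ultralytics goes: `InnerModel` is a hypothetical stand-in for the detection model, `QATWrapper` is an illustrative name, and the trainer may well read other attributes or rebuild the model regardless.

```python
import torch
from torch import nn


class InnerModel(nn.Module):
    """Hypothetical stand-in for the wrapped detection model."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # Plain attributes, similar in spirit to what the trainer reads.
        self.stride = torch.tensor([8.0, 16.0, 32.0])
        self.yaml = {"nc": 80}

    def forward(self, x):
        return [self.conv(x)]


class QATWrapper(nn.Module):
    """Wrapper that delegates unknown attributes to the wrapped model."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.quant = torch.quantization.QuantStub()
        self.de_quant = torch.quantization.DeQuantStub()

    def __getattr__(self, name):
        # nn.Module.__getattr__ resolves parameters, buffers and submodules.
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the wrapped model (e.g. for `stride` or `yaml`).
            return getattr(super().__getattr__("model"), name)

    def forward(self, x):
        x = self.quant(x)
        outputs = self.model(x)
        return [self.de_quant(o) for o in outputs]


wrapped = QATWrapper(InnerModel())
max_stride = int(wrapped.stride.max())  # resolved through the fallback
outputs = wrapped(torch.zeros(1, 3, 16, 16))
```

The fallback only runs when normal lookup fails, so the wrapper's own submodules (`model`, `quant`, `de_quant`) are still resolved first.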