Hello and thank you for this help forum.
I am trying to train a YOLO11 model on a subset of the classes in the Cityscapes dataset.
Since my results so far are only moderate, I would like to try another approach: a weighted data loader, so that the classes are balanced during training.
I tried to follow these instructions to achieve this: https://y-t-g.github.io/tutorials/yolo-class-balancing/
The approach seems to work on my local machine (which does not have GPU support). However, when I run it on my research institution's data cluster, the class `YOLOWeightedDataset` does not seem to be used, even though I monkey patched it over `dataset.YOLODataset`.
After a lot of debugging and banging my head against the keyboard, I now think that somewhere inside `YOLO.train()` a new process is spawned that does not know about the monkey patch and therefore does not use it.
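If that hypothesis is right, the effect can be reproduced without Ultralytics: a monkey patch only replaces an attribute on a module object in the *current* interpreter, and a freshly launched Python process imports the module again from disk, unpatched. The sketch below uses `json.dumps` as a stand-in for `dataset.YOLODataset` and `subprocess` as a stand-in for however `YOLO.train()` starts its workers. That `device=[0,1]` (multi-GPU training) launches such a separate process is my assumption here, although it would fit with the patch working on my GPU-less local machine.

```python
import json
import subprocess
import sys


def patched_dumps(*args, **kwargs):
    """Stand-in replacement, used only to detect whether the patch is active."""
    return "patched"


def patch_visible_in_child() -> bool:
    """Return True if a freshly launched interpreter sees the patched json.dumps."""
    # The child imports json from scratch, so it only sees the patch
    # if the patch somehow survived across the process boundary.
    child_code = "import json; print(json.dumps({}) == 'patched')"
    result = subprocess.run(
        [sys.executable, "-c", child_code],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip() == "True"


original_dumps = json.dumps
json.dumps = patched_dumps  # monkey patch: affects this process only

print("parent sees patch:", json.dumps({}) == "patched")  # True
print("child sees patch:", patch_visible_in_child())  # False

json.dumps = original_dumps  # undo the patch
```

This is why a patch applied in a notebook cell would be invisible to any worker process that starts a new interpreter instead of inheriting the notebook's memory.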
Is there another way to achieve weight-balanced training, e.g. by passing my custom class `YOLOWeightedDataset` to the `YOLO.train()` method in a different way?
Here is what I did so far:
```python
from ultralytics import YOLO
import ultralytics.data as data
import ultralytics.data.dataset as dataset
import ultralytics.data.build as build
import numpy as np


class YOLOWeightedDataset(data.dataset.YOLODataset):
    def __init__(self, *args, data=None, task="train", **kwargs):
        """
        Initialize the WeightedDataset.

        Class weights are not passed in; they are computed from the
        per-class instance counts (inverse frequency).
        """
        super().__init__(*args, data=data, task=task, **kwargs)

        self.train_mode = "train" in self.prefix

        # You can also specify weights manually instead
        self.count_instances()
        class_weights = np.sum(self.counts) / self.counts

        # Aggregation function
        self.agg_func = np.mean

        self.class_weights = np.array(class_weights)
        self.weights = self.calculate_weights()
        self.probabilities = self.calculate_probabilities()

    ...  # rest of the code from https://y-t-g.github.io/tutorials/yolo-class-balancing/


dataset.YOLODataset = YOLOWeightedDataset

model = YOLO("yolo11n.pt")
model.train(
    data="data.yaml",
    device=[0, 1],
    batch=128,
    imgsz=640,
    save_period=10,
    project="runs/detect/nano_balanced",
)
```
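For reference, the weighting lines in `__init__` (inverse-frequency `class_weights = np.sum(self.counts) / self.counts`, a per-image aggregate with `np.mean`, then normalized sampling probabilities) boil down to something like the following stdlib-only sketch. The function names and toy labels are hypothetical and are not the tutorial's `calculate_weights`/`calculate_probabilities`; giving images without labels a weight of 1.0 is also my own assumption.

```python
import random
from collections import Counter
from statistics import mean


def class_weights_from_labels(image_labels):
    """Inverse-frequency weight per class: total instances / instances of that class."""
    counts = Counter(cls for labels in image_labels for cls in labels)
    total = sum(counts.values())
    return {cls: total / n for cls, n in counts.items()}


def image_weights(image_labels, class_weights, agg_func=mean):
    """Aggregate the class weights of all instances in each image (mean by default)."""
    # Assumption of this sketch: images without any labels get weight 1.0.
    return [
        agg_func([class_weights[c] for c in labels]) if labels else 1.0
        for labels in image_labels
    ]


def sampling_probabilities(weights):
    """Normalize per-image weights so they sum to 1."""
    total = sum(weights)
    return [w / total for w in weights]


# Hypothetical toy dataset: class 0 is common, classes 1 and 2 are rare.
labels = [[0, 0, 0], [0, 1], [0, 0], [2]]
cw = class_weights_from_labels(labels)
w = image_weights(labels, cw)
p = sampling_probabilities(w)

# Draw one "epoch" of image indices with replacement, in proportion to p.
rng = random.Random(0)
epoch = rng.choices(range(len(labels)), weights=p, k=len(labels))
print(cw)
print([round(x, 3) for x in p])
print(epoch)
```

Here the rare classes get a weight of 8.0 versus about 1.33 for class 0, so the image containing only class 2 receives the highest sampling probability. How the tutorial actually feeds these probabilities into training is in the elided part of the class above.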
It might also be worth mentioning that I run this code in a Jupyter Notebook on my research institution's cluster.
Thank you very much in advance for your help!