I'm using a custom dataloader to train YOLO so that I can override how YOLO reads files, since I need to load them from a server rather than from the local filesystem.
I'm using a custom YOLO collate function:
import torch

def yolo_collate_fn(batch):
    # Stack images into a single (B, C, H, W) tensor
    imgs = torch.stack([sample['img'] for sample in batch], dim=0)
    cls_list = [sample['cls'] for sample in batch]
    bboxes_list = [sample['bboxes'] for sample in batch]
    img_paths = [sample['img_path'] for sample in batch]
    ori_shapes = [sample['ori_shape'] for sample in batch]
    # One batch index per box, so each box can be traced back to its image
    batch_idx = []
    for i, cls in enumerate(cls_list):
        batch_idx.append(torch.full((cls.shape[0],), i, dtype=torch.long))
    batch_idx = torch.cat(batch_idx, dim=0) if batch_idx else torch.tensor([], dtype=torch.long)
    return {
        'img': imgs,
        'cls': torch.cat(cls_list, dim=0) if cls_list else torch.tensor([], dtype=torch.float32),
        'bboxes': torch.cat(bboxes_list, dim=0) if bboxes_list else torch.tensor([], dtype=torch.float32),
        'batch_idx': batch_idx,
        'img_path': img_paths,
        'im_file': img_paths,
        'ori_shape': ori_shapes,
    }
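To make the batch_idx part concrete, here is a minimal sketch of what that loop produces. The two samples with 2 and 3 boxes are made up for illustration, not taken from my data:

```python
import torch

# Hypothetical batch: sample 0 has 2 boxes, sample 1 has 3 boxes
cls_list = [torch.zeros(2, 1), torch.zeros(3, 1)]

# Same logic as in the collate function: repeat the sample index once per box
batch_idx = torch.cat(
    [torch.full((c.shape[0],), i, dtype=torch.long) for i, c in enumerate(cls_list)],
    dim=0,
)
print(batch_idx)  # tensor([0, 0, 1, 1, 1])
```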
My bboxes have length 8, corresponding to the xyxyxyxy format (four corner points), for example:
tensor([0.7008, 0.9064, 0.7146, 0.9037, 0.7170, 0.9157, 0.7032, 0.9185])
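For reference, an 8-value corner box like this can be collapsed into a 5-value (cx, cy, w, h, angle) box. The sketch below is only an illustration under some assumptions: `corners_to_xywhr` is a hypothetical helper name, the four corners are listed in order around a true rectangle, the coordinate space is square (otherwise normalized corners no longer form a rectangle), and the angle is left in whatever range `atan2` returns, which may not be the range the loss expects. I believe Ultralytics also has its own helper for this (`xyxyxyxy2xywhr` in `ultralytics.utils.ops`), but I have not checked how it differs.

```python
import torch

def corners_to_xywhr(corners: torch.Tensor) -> torch.Tensor:
    """Convert (N, 8) corner boxes to (N, 5) boxes: cx, cy, w, h, angle (radians).

    Hypothetical helper; assumes corners are ordered around a true rectangle.
    """
    pts = corners.reshape(corners.shape[0], 4, 2)   # (N, 4 corners, xy)
    center = pts.mean(dim=1)                        # rectangle center
    edge_w = pts[:, 1] - pts[:, 0]                  # first edge, treated as the width
    edge_h = pts[:, 2] - pts[:, 1]                  # second edge, treated as the height
    w = edge_w.norm(dim=1)
    h = edge_h.norm(dim=1)
    angle = torch.atan2(edge_w[:, 1], edge_w[:, 0]) # rotation of the first edge
    return torch.cat([center, w[:, None], h[:, None], angle[:, None]], dim=1)

box = torch.tensor([[0.7008, 0.9064, 0.7146, 0.9037, 0.7170, 0.9157, 0.7032, 0.9185]])
xywhr = corners_to_xywhr(box)
print(xywhr.shape)  # torch.Size([1, 5])
```

With boxes in this shape, a reshape to (-1, 5) is at least dimensionally valid, though whether the values match the trainer's conventions is a separate question.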
I'm training like this:
overrides = {
    'model': "yolo11n-obb.pt",
    'task': 'detect-obb',
    'imgsz': 768,
    'epochs': 1,
    'batch': BATCH_SIZE,
    'device': 0,
    'workers': 0,
    'verbose': True,
    'data': 'data.yaml',
}
trainer = Custom_OBBTrainer(overrides=overrides)
trainer.train()
Yet when I train, it seems the loss function expects a 5-value format (xywh plus angle) instead, as I get this error:
RuntimeError Traceback (most recent call last)
~/.conda/envs/py-cv-client-ipython/lib/python3.9/site-packages/ultralytics/utils/loss.py in __call__(self, preds, batch)
701 batch_idx = batch["batch_idx"].view(-1, 1)
--> 702 targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
703 rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
RuntimeError: shape '[-1, 5]' is invalid for input of size 64
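The size of 64 is consistent with 8-value boxes being reshaped to 5 columns. A minimal reproduction, assuming 8 boxes in the batch (the actual count is a guess; any number of 8-value boxes that totals 64 elements behaves the same way):

```python
import torch

# Hypothetical batch: 8 boxes x 8 corner values = 64 elements
bboxes = torch.zeros(8, 8)
print(bboxes.numel(), bboxes.numel() % 5)  # 64 4

# 64 is not divisible by 5, so the same view() call as in loss.py fails
failed = False
try:
    bboxes.view(-1, 5)
except RuntimeError as err:
    failed = True
    print(err)
```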
Is it possible to train with a custom dataloader using the xyxyxyxy format, or should I convert all my data to the 5-value format?
Thanks