I’m trying to combine the weights of multiple yolov8n models. It does not seem to be currently supported within the ultralytics package, so I’m wondering if there is a clean way to do this manually?
I have tried loading the models with:
#load trained models as state_dicts
modelA = torch.load('yolov8n-1.pt', map_location=torch.device('cpu'))['model'].state_dict()
modelB = torch.load('yolov8n-2.pt', map_location=torch.device('cpu'))['model'].state_dict()
#average trained model weights
modelC = ... #---some method of averaging weights---
#initialize final model
model = YOLO('yolov8n.yaml')
#load torch model into yolov8n model
model.load(torchmodel)
Unfortunately, this does not work as expected. I also tried the same round trip without altering the weights, simply extracting a state_dict and reloading it, and the weights appear to get shuffled around in the process.
Has anyone encountered this or come across a solution for editing weights in yolov8 models?
If anyone else stumbles upon this issue, I figured out my mistake.
When loading a state_dict back into a torch model, you must use:
torchmodel['model'].load_state_dict(state_dict_variable)
For a full implementation on a yolov8 model, see below:
import torch
from ultralytics import YOLO
#create empty yolov8 model
model = YOLO('yolov8n.yaml')
#load weights from desired model as state_dict
torchmodel_sd = torch.load('yolov8n.pt')['model'].state_dict()
'''perform some operations on weights'''
#load a dummy torch model with same architecture as target model
torchmodel = torch.load('yolov8n.pt')
#load modified weights into torch model
torchmodel['model'].load_state_dict(torchmodel_sd)
#load torch model into yolov8 model
model.load(weights=torchmodel)
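To check that a round trip like the one above really leaves the weights untouched, the state_dicts can be compared tensor by tensor before and after reloading. This is only a minimal sketch: `state_dicts_match` and `make_net` are hypothetical helpers, and a small `nn.Sequential` stands in for the YOLOv8 network so the example runs without any weight files.

```python
import torch
from torch import nn


def state_dicts_match(sd_a, sd_b):
    """Return True if both state_dicts have the same keys and identical tensors."""
    if sd_a.keys() != sd_b.keys():
        return False
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)


def make_net():
    # Hypothetical stand-in architecture; any nn.Module works the same way
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())


source = make_net()
target = make_net()  # same architecture, different random initialization

# Clone the tensors so later edits do not alias the source model's parameters
weights = {k: v.clone() for k, v in source.state_dict().items()}

# strict=True (the default) raises if any key is missing or unexpected
result = target.load_state_dict(weights, strict=True)

roundtrip_ok = state_dicts_match(source.state_dict(), target.state_dict())
```

If `roundtrip_ok` comes back False after a reload, the returned `result.missing_keys` and `result.unexpected_keys` are a reasonable first place to look.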
Alright, so it turns out (unsurprisingly in hindsight) that loading the checkpoint through torch directly is not necessary for this, since YOLOv8 is built on PyTorch.
You can simply extract, edit, and reload the weights of an Ultralytics model as follows:
from ultralytics import YOLO
mymodel = YOLO('path/to/model')
weights = mymodel.model.state_dict()
'''some operation on weights'''
mymodel.reset_weights() #probably not required, but good for troubleshooting
mymodel.model.load_state_dict(weights)
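For the averaging step from the original question, one possible approach is an element-wise mean over the state_dicts. This is a sketch under assumptions, not something the Ultralytics package provides: `average_state_dicts` and `make_net` are hypothetical names, all models are assumed to share exactly the same architecture, and integer buffers (such as BatchNorm's `num_batches_tracked`) are copied from the first model rather than averaged. A small stand-in network is used so the example runs without weight files.

```python
import torch
from torch import nn


def average_state_dicts(state_dicts):
    """Element-wise mean of several state_dicts that share keys and shapes."""
    first = state_dicts[0]
    if any(sd.keys() != first.keys() for sd in state_dicts):
        raise ValueError("all state_dicts must have the same keys")
    averaged = {}
    for key, ref in first.items():
        if ref.is_floating_point():
            # Accumulate in float32, then cast back to the original dtype
            stacked = torch.stack([sd[key].float() for sd in state_dicts])
            averaged[key] = stacked.mean(dim=0).to(ref.dtype)
        else:
            # Integer buffers are copied from the first model, not averaged
            averaged[key] = ref.clone()
    return averaged


def make_net():
    # Hypothetical stand-in for two trained models with identical architecture
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))


net_a, net_b, net_c = make_net(), make_net(), make_net()
merged = average_state_dicts([net_a.state_dict(), net_b.state_dict()])
net_c.load_state_dict(merged)
```

With Ultralytics models, the same function could presumably be fed the `model.state_dict()` of each loaded model, with the result passed to `mymodel.model.load_state_dict(...)` as shown above. Whether the averaged model performs well is a separate question that depends on how the source models were trained.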