fix yolo classify model loading error (#9196)

This commit is contained in:
Mo Li 2024-03-23 01:04:34 +08:00 committed by GitHub
parent 8617fcf32d
commit 292e028779
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -73,7 +73,7 @@ class ClassificationTrainer(BaseTrainer):
elif model in torchvision.models.__dict__: elif model in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None) self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
else: else:
FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.") raise FileNotFoundError(f"ERROR: model={model} not found locally or online. Please check model name.")
ClassificationModel.reshape_outputs(self.model, self.data["nc"]) ClassificationModel.reshape_outputs(self.model, self.data["nc"])
return ckpt return ckpt