mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 14:44:21 +08:00
Default classify training to pretrained=True
(#3239)
This commit is contained in:
parent
e78fb683f4
commit
15c90bd404
@ -19,7 +19,7 @@ workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)
|
||||
project: # (str, optional) project name
|
||||
name: # (str, optional) experiment name, results saved to 'project/name' directory
|
||||
exist_ok: False # (bool) whether to overwrite existing experiment
|
||||
pretrained: False # (bool) whether to use a pretrained model
|
||||
pretrained: True # (bool) whether to use a pretrained model
|
||||
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
|
||||
verbose: True # (bool) whether to print verbose output
|
||||
seed: 0 # (int) random seed for reproducibility
|
||||
|
@ -33,9 +33,8 @@ class ClassificationTrainer(BaseTrainer):
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
pretrained = self.args.pretrained
|
||||
for m in model.modules():
|
||||
if not pretrained and hasattr(m, 'reset_parameters'):
|
||||
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
||||
m.p = self.args.dropout # set dropout
|
||||
@ -61,8 +60,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
elif model.endswith('.yaml'):
|
||||
self.model = self.get_model(cfg=model)
|
||||
elif model in torchvision.models.__dict__:
|
||||
pretrained = True
|
||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
|
||||
else:
|
||||
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user