mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 14:44:21 +08:00
Default classify training to pretrained=True
(#3239)
This commit is contained in:
parent
e78fb683f4
commit
15c90bd404
@ -19,7 +19,7 @@ workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)
|
|||||||
project: # (str, optional) project name
|
project: # (str, optional) project name
|
||||||
name: # (str, optional) experiment name, results saved to 'project/name' directory
|
name: # (str, optional) experiment name, results saved to 'project/name' directory
|
||||||
exist_ok: False # (bool) whether to overwrite existing experiment
|
exist_ok: False # (bool) whether to overwrite existing experiment
|
||||||
pretrained: False # (bool) whether to use a pretrained model
|
pretrained: True # (bool) whether to use a pretrained model
|
||||||
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
|
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
|
||||||
verbose: True # (bool) whether to print verbose output
|
verbose: True # (bool) whether to print verbose output
|
||||||
seed: 0 # (int) random seed for reproducibility
|
seed: 0 # (int) random seed for reproducibility
|
||||||
|
@ -33,9 +33,8 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
if weights:
|
if weights:
|
||||||
model.load(weights)
|
model.load(weights)
|
||||||
|
|
||||||
pretrained = self.args.pretrained
|
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if not pretrained and hasattr(m, 'reset_parameters'):
|
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
|
||||||
m.reset_parameters()
|
m.reset_parameters()
|
||||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
||||||
m.p = self.args.dropout # set dropout
|
m.p = self.args.dropout # set dropout
|
||||||
@ -61,8 +60,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
elif model.endswith('.yaml'):
|
elif model.endswith('.yaml'):
|
||||||
self.model = self.get_model(cfg=model)
|
self.model = self.get_model(cfg=model)
|
||||||
elif model in torchvision.models.__dict__:
|
elif model in torchvision.models.__dict__:
|
||||||
pretrained = True
|
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
|
||||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
|
||||||
else:
|
else:
|
||||||
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||||
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user