mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
build_optimizer()
assign all parameters (#2855)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
441e67d330
commit
61fa5efe6d
@ -618,15 +618,19 @@ class BaseTrainer:
|
||||
Returns:
|
||||
optimizer (torch.optim.Optimizer): the built optimizer
|
||||
"""
|
||||
|
||||
g = [], [], [] # optimizer parameter groups
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
for v in model.modules():
|
||||
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias (no decay)
|
||||
g[2].append(v.bias)
|
||||
if isinstance(v, bn): # weight (no decay)
|
||||
g[1].append(v.weight)
|
||||
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
|
||||
g[0].append(v.weight)
|
||||
|
||||
for module_name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
fullname = f'{module_name}.{param_name}' if module_name else param_name
|
||||
if 'bias' in fullname: # bias (no decay)
|
||||
g[2].append(param)
|
||||
elif isinstance(module, bn): # weight (no decay)
|
||||
g[1].append(param)
|
||||
else: # weight (with decay)
|
||||
g[0].append(param)
|
||||
|
||||
if name == 'Adam':
|
||||
optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999)) # adjust beta1 to momentum
|
||||
|
Loading…
x
Reference in New Issue
Block a user