ultralytics 8.0.154 add `freeze` training argument (#4329)

parent 9f6d48d3cf
commit d47718c367
@@ -141,7 +141,7 @@ Remember that checkpoints are saved at the end of every epoch by default, or at
 Training settings for YOLO models refer to the various hyperparameters and configurations used to train the model on a dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO training settings include the batch size, learning rate, momentum, and weight decay. Other factors that may affect the training process include the choice of optimizer, the choice of loss function, and the size and composition of the training dataset. It is important to carefully tune and experiment with these settings to achieve the best possible performance for a given task.

 | Key               | Value    | Description                                                                         |
-|-------------------|----------|-----------------------------------------------------------------------------------|
+|-------------------|----------|------------------------------------------------------------------------------------------------|
 | `model`           | `None`   | path to model file, i.e. yolov8n.pt, yolov8n.yaml                                   |
 | `data`            | `None`   | path to data file, i.e. coco128.yaml                                                |
 | `epochs`          | `100`    | number of epochs to train for                                                       |
@@ -169,6 +169,7 @@ Training settings for YOLO models refer to the various hyperparameters and confi
 | `amp`             | `True`   | Automatic Mixed Precision (AMP) training, choices=[True, False]                     |
 | `fraction`        | `1.0`    | dataset fraction to train on (default is 1.0, all images in train set)              |
 | `profile`         | `False`  | profile ONNX and TensorRT speeds during training for loggers                        |
+| `freeze`          | `None`   | (int or list, optional) freeze first n layers, or freeze list of layer indices during training |
 | `lr0`             | `0.01`   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3)                                    |
 | `lrf`             | `0.01`   | final learning rate (lr0 * lrf)                                                     |
 | `momentum`        | `0.937`  | SGD momentum/Adam beta1                                                             |
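The table rows above map one-to-one onto keyword arguments of the Python API. A minimal sketch, assuming the standard `ultralytics` entry point, with values mirroring the documented defaults:

    from ultralytics import YOLO

    # Train with the settings documented above; unspecified keys keep their defaults.
    model = YOLO('yolov8n.pt')    # 'model': path to model file
    model.train(
        data='coco128.yaml',      # 'data': path to data file
        epochs=100,               # 'epochs': number of epochs to train for
        lr0=0.01,                 # 'lr0': initial learning rate
        lrf=0.01,                 # 'lrf': final learning rate = lr0 * lrf
        momentum=0.937,           # 'momentum': SGD momentum / Adam beta1
    )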
@@ -79,7 +79,7 @@ include:
 The training settings for YOLO models encompass various hyperparameters and configurations used during the training process. These settings influence the model's performance, speed, and accuracy. Key training settings include batch size, learning rate, momentum, and weight decay. Additionally, the choice of optimizer, loss function, and training dataset composition can impact the training process. Careful tuning and experimentation with these settings are crucial for optimizing performance.

 | Key               | Value    | Description                                                                         |
-|-------------------|----------|-----------------------------------------------------------------------------------|
+|-------------------|----------|------------------------------------------------------------------------------------------------|
 | `model`           | `None`   | path to model file, i.e. yolov8n.pt, yolov8n.yaml                                   |
 | `data`            | `None`   | path to data file, i.e. coco128.yaml                                                |
 | `epochs`          | `100`    | number of epochs to train for                                                       |
@@ -107,6 +107,7 @@ The training settings for YOLO models encompass various hyperparameters and conf
 | `amp`             | `True`   | Automatic Mixed Precision (AMP) training, choices=[True, False]                     |
 | `fraction`        | `1.0`    | dataset fraction to train on (default is 1.0, all images in train set)              |
 | `profile`         | `False`  | profile ONNX and TensorRT speeds during training for loggers                        |
+| `freeze`          | `None`   | (int or list, optional) freeze first n layers, or freeze list of layer indices during training |
 | `lr0`             | `0.01`   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3)                                    |
 | `lrf`             | `0.01`   | final learning rate (lr0 * lrf)                                                     |
 | `momentum`        | `0.937`  | SGD momentum/Adam beta1                                                             |
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.153'
+__version__ = '8.0.154'

 from ultralytics.hub import start
 from ultralytics.models import RTDETR, SAM, YOLO
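A trivial check, shown for completeness, of which side of this bump an installed package sits on:

    import ultralytics

    # prints '8.0.154' once this commit is installed, '8.0.153' before it
    print(ultralytics.__version__)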
@@ -32,6 +32,7 @@ resume: False # (bool) resume training from last checkpoint
 amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
 fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
 profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
+freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training

 # Segmentation
 overlap_mask: True # (bool) masks should overlap during training (segment train only)
 mask_ratio: 4 # (int) mask downsample ratio (segment train only)
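With the new default in place, `freeze` can be overridden per run like any other training argument. A short sketch of both accepted forms, int and list, through the public API:

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')

    # int form: freeze the first 10 layers (model.0. through model.9.)
    model.train(data='coco128.yaml', epochs=100, freeze=10)

    # list form: freeze only the layers at these indices
    model.train(data='coco128.yaml', epochs=100, freeze=[0, 1, 2])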
@@ -207,11 +207,28 @@ class BaseTrainer:
         """
         Builds dataloaders and optimizer on correct rank process.
         """

         # Model
         self.run_callbacks('on_pretrain_routine_start')
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()

+        # Freeze layers
+        freeze_list = self.args.freeze if isinstance(
+            self.args.freeze, list) else range(self.args.freeze) if isinstance(self.args.freeze, int) else []
+        always_freeze_names = ['.dfl']  # always freeze these layers
+        freeze_layer_names = [f'model.{x}.' for x in freeze_list] + always_freeze_names
+        for k, v in self.model.named_parameters():
+            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
+            if any(x in k for x in freeze_layer_names):
+                LOGGER.info(f"Freezing layer '{k}'")
+                v.requires_grad = False
+            elif not v.requires_grad:
+                LOGGER.info(f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
+                            'See ultralytics.engine.trainer for customization of frozen layers.')
+                v.requires_grad = True
+
         # Check AMP
         self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
         if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
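The nested ternary in the added block is dense; here is a standalone paraphrase of the same resolution logic (the helper name is hypothetical, not part of the commit). Note that the trainer then matches these prefixes by substring against `named_parameters()` keys:

    def resolve_freeze(freeze):
        """Mirror the trainer's logic: turn an int, list, or None 'freeze'
        argument into the parameter-name prefixes that get requires_grad=False."""
        if isinstance(freeze, list):
            freeze_list = freeze              # explicit layer indices
        elif isinstance(freeze, int):
            freeze_list = range(freeze)       # first n layers
        else:
            freeze_list = []                  # None: only the always-frozen layers
        # '.dfl' (Distribution Focal Loss) modules are frozen unconditionally
        return [f'model.{x}.' for x in freeze_list] + ['.dfl']

    print(resolve_freeze(2))       # ['model.0.', 'model.1.', '.dfl']
    print(resolve_freeze([4, 7]))  # ['model.4.', 'model.7.', '.dfl']
    print(resolve_freeze(None))    # ['.dfl']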
@@ -224,9 +241,11 @@ class BaseTrainer:
         self.scaler = amp.GradScaler(enabled=self.amp)
         if world_size > 1:
             self.model = DDP(self.model, device_ids=[RANK])

         # Check imgsz
         gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
         self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)

         # Batch size
         if self.batch_size == -1:
             if RANK == -1:  # single-GPU only, estimate best batch size
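For readers unfamiliar with `amp.GradScaler`, a generic PyTorch mixed-precision step looks like this (illustrative sketch only, not trainer code; requires a CUDA device):

    import torch
    from torch.cuda import amp

    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = amp.GradScaler(enabled=True)  # same class the trainer instantiates

    for _ in range(3):
        x = torch.randn(4, 10, device='cuda')
        with amp.autocast(enabled=True):   # run the forward pass in mixed precision
            loss = model(x).square().mean()
        scaler.scale(loss).backward()      # scale loss to avoid fp16 gradient underflow
        scaler.step(optimizer)             # unscales gradients, then optimizer.step()
        scaler.update()                    # adapt the scale factor for the next step
        optimizer.zero_grad()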
@@ -272,7 +291,6 @@ class BaseTrainer:
         """Train completed, evaluate and plot if specified by arguments."""
         if world_size > 1:
             self._setup_ddp(world_size)

         self._setup_train(world_size)

         self.epoch_time = None
@@ -142,8 +142,9 @@ class GhostConv(nn.Module):


 class RepConv(nn.Module):
-    """RepConv is a basic rep-style block, including training and deploy status
-    This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
-    """
+    """
+    RepConv is a basic rep-style block, including training and deploy status. This module is used in RT-DETR.
+    Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
+    """
     default_act = nn.SiLU()  # default activation
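As background on the "rep-style" the new docstring names: parallel train-time branches are algebraically fused into a single convolution for deployment. A toy sketch of the RepVGG-style fusion (simplified, not the module's actual implementation):

    import torch
    import torch.nn as nn

    # Train-time structure: parallel 3x3 and 1x1 branches whose outputs are summed.
    conv3 = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)
    conv1 = nn.Conv2d(8, 8, kernel_size=1, bias=False)

    # Deploy-time: zero-pad the 1x1 kernel to 3x3 and add it to the 3x3 kernel,
    # producing one convolution with identical output (convolution is linear in weights).
    fused = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)
    with torch.no_grad():
        fused.weight.copy_(conv3.weight + nn.functional.pad(conv1.weight, [1, 1, 1, 1]))

    x = torch.randn(1, 8, 16, 16)
    assert torch.allclose(conv3(x) + conv1(x), fused(x), atol=1e-5)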