mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Deterministic training (#53)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
793dde365d
commit
c5f5b80c04
@ -28,16 +28,19 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
|
|||||||
from ultralytics.yolo.utils.checks import print_args
|
from ultralytics.yolo.utils.checks import print_args
|
||||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||||
from ultralytics.yolo.utils.modeling import get_model
|
from ultralytics.yolo.utils.modeling import get_model
|
||||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, one_cycle
|
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle
|
||||||
|
|
||||||
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
|
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
|
||||||
|
RANK = int(os.getenv('RANK', -1))
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainer:
|
class BaseTrainer:
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||||
self.console = LOGGER
|
|
||||||
self.args = self._get_config(config, overrides)
|
self.args = self._get_config(config, overrides)
|
||||||
|
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||||
|
|
||||||
|
self.console = LOGGER
|
||||||
self.validator = None
|
self.validator = None
|
||||||
self.model = None
|
self.model = None
|
||||||
self.callbacks = defaultdict(list)
|
self.callbacks = defaultdict(list)
|
||||||
|
@ -22,6 +22,7 @@ pretrained: False
|
|||||||
optimizer: 'SGD' # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
optimizer: 'SGD' # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
||||||
verbose: False
|
verbose: False
|
||||||
seed: 0
|
seed: 0
|
||||||
|
deterministic: True
|
||||||
local_rank: -1
|
local_rank: -1
|
||||||
single_cls: False # train multi-class data as single-class
|
single_cls: False # train multi-class data as single-class
|
||||||
image_weights: False # use weighted image selection for training
|
image_weights: False # use weighted image selection for training
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import thop
|
import thop
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -199,6 +201,21 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):
|
|||||||
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
|
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
|
||||||
|
|
||||||
|
|
||||||
|
def init_seeds(seed=0, deterministic=False):
|
||||||
|
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
|
||||||
|
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
|
||||||
|
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
|
||||||
|
torch.use_deterministic_algorithms(True)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
||||||
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
|
|
||||||
|
|
||||||
class ModelEMA:
|
class ModelEMA:
|
||||||
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
||||||
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user