mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
DDP and new dataloader Fix (#95)
Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
16e3c08883
commit
4fb04be20b
@ -51,15 +51,15 @@ repos:
|
||||
additional_dependencies:
|
||||
- mdformat-gfm
|
||||
- mdformat-black
|
||||
exclude: "README.md|README_cn.md| CONTRIBUTING.md"
|
||||
|
||||
- repo: https://github.com/asottile/yesqa
|
||||
rev: v1.4.0
|
||||
hooks:
|
||||
- id: yesqa
|
||||
exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 5.0.4
|
||||
hooks:
|
||||
- id: flake8
|
||||
name: PEP8
|
||||
|
||||
#- repo: https://github.com/asottile/yesqa
|
||||
# rev: v1.4.0
|
||||
# hooks:
|
||||
# - id: yesqa
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from ultralytics import yolo # (required for python usage)
|
||||
from ultralytics import yolo # noqa required for python usage
|
||||
# from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
|
@ -13,7 +13,6 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
from torch.cuda import amp
|
||||
@ -111,8 +110,12 @@ class BaseTrainer:
|
||||
if world_size > 1 and "LOCAL_RANK" not in os.environ:
|
||||
command = generate_ddp_command(world_size, self)
|
||||
print('DDP command: ', command)
|
||||
subprocess.Popen(command)
|
||||
# ddp_cleanup(command, self) # TODO: uncomment and fix
|
||||
try:
|
||||
subprocess.run(command)
|
||||
except Exception as e:
|
||||
self.console(e)
|
||||
finally:
|
||||
ddp_cleanup(command, self)
|
||||
else:
|
||||
self._do_train(int(os.getenv("RANK", -1)), world_size)
|
||||
|
||||
@ -122,7 +125,6 @@ class BaseTrainer:
|
||||
torch.cuda.set_device(rank)
|
||||
self.device = torch.device('cuda', rank)
|
||||
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
|
||||
mp.set_start_method('spawn', force=True)
|
||||
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
@ -159,8 +161,8 @@ class BaseTrainer:
|
||||
if rank in {0, -1}:
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
|
||||
self.validator = self.get_validator()
|
||||
# metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
||||
# self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||
metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||
self.ema = ModelEMA(self.model)
|
||||
self.trigger_callbacks("on_pretrain_routine_end")
|
||||
|
||||
|
@ -3,7 +3,8 @@ import shutil
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from . import USER_CONFIG_DIR
|
||||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
@ -23,25 +24,25 @@ def find_free_network_port() -> int:
|
||||
def generate_ddp_file(trainer):
|
||||
import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])
|
||||
|
||||
# remove the save_dir
|
||||
shutil.rmtree(trainer.save_dir)
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__":
|
||||
from ultralytics.{import_path} import {trainer.__class__.__name__}
|
||||
|
||||
trainer = {trainer.__class__.__name__}(overrides=overrides)
|
||||
trainer.train()'''
|
||||
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(prefix="_temp_",
|
||||
suffix=f"{id(trainer)}.py",
|
||||
mode="w+",
|
||||
encoding='utf-8',
|
||||
dir=os.path.curdir,
|
||||
dir=USER_CONFIG_DIR / 'DDP',
|
||||
delete=False) as file:
|
||||
file.write(content)
|
||||
return file.name
|
||||
|
||||
|
||||
def generate_ddp_command(world_size, trainer):
|
||||
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
file_name = os.path.abspath(sys.argv[0])
|
||||
using_cli = not file_name.endswith(".py")
|
||||
if using_cli:
|
||||
@ -52,9 +53,7 @@ def generate_ddp_command(world_size, trainer):
|
||||
|
||||
|
||||
def ddp_cleanup(command, trainer):
|
||||
# delete temp file if created
|
||||
# TODO: this is a temp solution in case the file is deleted before DDP launching
|
||||
time.sleep(5)
|
||||
# delete temp file if created
|
||||
tempfile_suffix = f"{id(trainer)}.py"
|
||||
if tempfile_suffix in "".join(command):
|
||||
for chunk in command:
|
||||
|
@ -58,8 +58,6 @@ class DetectionTrainer(BaseTrainer):
|
||||
model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
|
||||
if weights:
|
||||
model.load(weights)
|
||||
for _, v in model.named_parameters():
|
||||
v.requires_grad = True # train all layers
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
|
@ -21,8 +21,6 @@ class SegmentationTrainer(DetectionTrainer):
|
||||
model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
|
||||
if weights:
|
||||
model.load(weights)
|
||||
for _, v in model.named_parameters():
|
||||
v.requires_grad = True # train all layers
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user