mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 21:44:22 +08:00

DDP and new dataloader Fix (#95)

Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 16e3c08883
commit 4fb04be20b
@@ -51,15 +51,15 @@ repos:
       additional_dependencies:
         - mdformat-gfm
         - mdformat-black
-      exclude: "README.md|README_cn.md| CONTRIBUTING.md"
+      exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"

-  - repo: https://github.com/asottile/yesqa
-    rev: v1.4.0
-    hooks:
-      - id: yesqa
-
   - repo: https://github.com/PyCQA/flake8
     rev: 5.0.4
     hooks:
       - id: flake8
         name: PEP8
+
+  #- repo: https://github.com/asottile/yesqa
+  #  rev: v1.4.0
+  #  hooks:
+  #    - id: yesqa
@@ -1,7 +1,7 @@
 import torch
 import yaml

-from ultralytics import yolo  # (required for python usage)
+from ultralytics import yolo  # noqa required for python usage
 # from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import LOGGER
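A note on the `# noqa` edit above: with flake8 kept in the pre-commit config (first hunk), an import that is never referenced raises F401, and `# noqa` suppresses lint checks on that one line. In isolation:

# Without the directive, flake8 reports:
#   F401 'ultralytics.yolo' imported but unused
# The marker silences flake8 for this line only; the import stays because,
# per the original comment, it is required for python usage.
from ultralytics import yolo  # noqa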
@@ -13,7 +13,6 @@ from pathlib import Path
 import numpy as np
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 from omegaconf import OmegaConf
 from torch.cuda import amp
@@ -111,8 +110,12 @@ class BaseTrainer:
         if world_size > 1 and "LOCAL_RANK" not in os.environ:
             command = generate_ddp_command(world_size, self)
             print('DDP command: ', command)
-            subprocess.Popen(command)
-            # ddp_cleanup(command, self)  # TODO: uncomment and fix
+            try:
+                subprocess.run(command)
+            except Exception as e:
+                self.console(e)
+            finally:
+                ddp_cleanup(command, self)
         else:
             self._do_train(int(os.getenv("RANK", -1)), world_size)

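The hunk above replaces the fire-and-forget `subprocess.Popen` with a blocking `subprocess.run`, so cleanup only runs after every DDP worker has exited; that is also what lets a later hunk drop the `time.sleep(5)` guard in `ddp_cleanup`. A minimal sketch of the same launch-then-clean-up shape, with a hypothetical command and temp-file path standing in for the real `generate_ddp_command` output:

import subprocess

# hypothetical stand-ins for what generate_ddp_command(world_size, self) returns
command = ["python", "-m", "torch.distributed.run", "--nproc_per_node", "2", "/tmp/_temp_12345.py"]
temp_file = "/tmp/_temp_12345.py"

try:
    subprocess.run(command)   # blocks until all DDP workers exit
except Exception as e:
    print(e)                  # surface launch failures instead of swallowing them
finally:
    # safe to delete now: no worker can still be importing the temp file
    print(f"cleanup would remove {temp_file}")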
@@ -122,7 +125,6 @@ class BaseTrainer:
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
-        mp.set_start_method('spawn', force=True)
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)

     def _setup_train(self, rank, world_size):
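With `subprocess.run` launching the workers as separate interpreters, the `mp.set_start_method('spawn', force=True)` call (and the `torch.multiprocessing` import removed in an earlier hunk) is no longer needed: each rank just binds its GPU and joins the process group. A self-contained sketch of that per-rank setup, assuming the usual `MASTER_ADDR`/`MASTER_PORT`/`RANK`/`WORLD_SIZE` environment variables that `torch.distributed.run` exports:

import os

import torch
import torch.distributed as dist


def setup_ddp(rank: int, world_size: int) -> torch.device:
    # each worker pins itself to one GPU before any collective call
    torch.cuda.set_device(rank)
    device = torch.device('cuda', rank)
    # NCCL is the fast path for CUDA tensors; gloo is the CPU/Windows fallback
    backend = "nccl" if dist.is_nccl_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    return device


if __name__ == "__main__":
    setup_ddp(int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]))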
@@ -159,8 +161,8 @@ class BaseTrainer:
         if rank in {0, -1}:
             self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
             self.validator = self.get_validator()
-            # metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
-            # self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
+            metric_keys = self.validator.metric_keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # TODO: init metrics for plot_results()?
             self.ema = ModelEMA(self.model)
         self.trigger_callbacks("on_pretrain_routine_end")

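Un-commenting the two metric lines gives `self.metrics` a stable, zero-initialized set of keys before the first validation pass, which plotting and logging code can rely on. The `dict(zip(...))` idiom in isolation, with hypothetical key names (the real ones come from the validator and `label_loss_items`):

metric_keys = ["metrics/mAP50", "metrics/mAP50-95", "val/box_loss"]  # hypothetical
metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
print(metrics)  # {'metrics/mAP50': 0, 'metrics/mAP50-95': 0, 'val/box_loss': 0}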
@@ -3,7 +3,8 @@ import shutil
 import socket
 import sys
 import tempfile
-import time
+
+from . import USER_CONFIG_DIR


 def find_free_network_port() -> int:
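Writing the generated DDP file under a per-user config directory (rather than the current working directory, see the next hunk) avoids littering the project tree and works when the CWD is read-only. `USER_CONFIG_DIR` is imported from the package's own `__init__`; a rough stand-in, assuming a Linux-style location (the real resolution is platform-dependent):

from pathlib import Path

# stand-in only: the package computes its own platform-specific location
USER_CONFIG_DIR = Path.home() / ".config" / "Ultralytics"

ddp_dir = USER_CONFIG_DIR / "DDP"
ddp_dir.mkdir(parents=True, exist_ok=True)  # exist_ok=True keeps repeat runs idempotent
print(ddp_dir)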
@@ -23,25 +24,25 @@ def find_free_network_port() -> int:
 def generate_ddp_file(trainer):
     import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])

-    # remove the save_dir
-    shutil.rmtree(trainer.save_dir)
+    shutil.rmtree(trainer.save_dir)  # remove the save_dir
     content = f'''overrides = {dict(trainer.args)} \nif __name__ == "__main__":
     from ultralytics.{import_path} import {trainer.__class__.__name__}

     trainer = {trainer.__class__.__name__}(overrides=overrides)
     trainer.train()'''
+    (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
     with tempfile.NamedTemporaryFile(prefix="_temp_",
                                      suffix=f"{id(trainer)}.py",
                                      mode="w+",
                                      encoding='utf-8',
-                                     dir=os.path.curdir,
+                                     dir=USER_CONFIG_DIR / 'DDP',
                                      delete=False) as file:
         file.write(content)
     return file.name


 def generate_ddp_command(world_size, trainer):
-    import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
     file_name = os.path.abspath(sys.argv[0])
     using_cli = not file_name.endswith(".py")
     if using_cli:
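For orientation, this is roughly what `generate_ddp_file` renders from the f-string above; the overrides dict and trainer class are illustrative placeholders (real values come from `trainer.args` and `trainer.__class__`):

# _temp_<id>.py, written to USER_CONFIG_DIR / 'DDP' (illustrative values)
overrides = {'task': 'detect', 'mode': 'train', 'data': 'coco128.yaml'}
if __name__ == "__main__":
    from ultralytics.yolo.v8.detect.train import DetectionTrainer

    trainer = DetectionTrainer(overrides=overrides)
    trainer.train()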
@@ -52,9 +53,7 @@ def generate_ddp_command(world_size, trainer):


 def ddp_cleanup(command, trainer):
     # delete temp file if created
-    # TODO: this is a temp solution in case the file is deleted before DDP launching
-    time.sleep(5)
     tempfile_suffix = f"{id(trainer)}.py"
     if tempfile_suffix in "".join(command):
         for chunk in command:
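Because the launch is now blocking, `ddp_cleanup` can only run after all workers have finished importing the temp file, so the `time.sleep(5)` guard goes away. A minimal sketch of the matching-and-delete logic; the loop body is truncated in the hunk, so the deletion step below is an assumption:

import os


def ddp_cleanup(command, trainer):
    # the temp file was created with suffix=f"{id(trainer)}.py", so only a
    # command referencing *this* trainer's file triggers deletion
    tempfile_suffix = f"{id(trainer)}.py"
    if tempfile_suffix in "".join(command):
        for chunk in command:
            if tempfile_suffix in chunk:
                os.remove(chunk)  # assumed body; the hunk cuts off here
                break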
@@ -58,8 +58,6 @@ class DetectionTrainer(BaseTrainer):
         model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
         if weights:
             model.load(weights)
-        for _, v in model.named_parameters():
-            v.requires_grad = True  # train all layers
         return model

     def get_validator(self):
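The two deleted lines force-enabled gradients on every parameter after loading weights; dropping them leaves parameters with whatever `requires_grad` state the model constructor set (freshly constructed `nn.Parameter`s default to `requires_grad=True` anyway, and loading a state dict does not change the flag). For reference, the removed idiom:

import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for DetectionModel
for _, v in model.named_parameters():
    v.requires_grad = True  # train all layers (the behavior this commit removes)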
@@ -21,8 +21,6 @@ class SegmentationTrainer(DetectionTrainer):
         model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
         if weights:
             model.load(weights)
-        for _, v in model.named_parameters():
-            v.requires_grad = True  # train all layers
         return model

     def get_validator(self):