mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-07-08 06:34:23 +08:00
Update generate_ddp_file for improved overrides (#2909)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
facb7861cf
commit
305cde69d0
@@ -182,7 +182,7 @@ class BaseTrainer:
|
|||||||
# Command
|
# Command
|
||||||
cmd, file = generate_ddp_command(world_size, self)
|
cmd, file = generate_ddp_command(world_size, self)
|
||||||
try:
|
try:
|
||||||
LOGGER.info(f'Running DDP command {cmd}')
|
LOGGER.info(f'DDP command: {cmd}')
|
||||||
subprocess.run(cmd, check=True)
|
subprocess.run(cmd, check=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
@@ -195,7 +195,7 @@ class BaseTrainer:
|
|||||||
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||||
torch.cuda.set_device(RANK)
|
torch.cuda.set_device(RANK)
|
||||||
self.device = torch.device('cuda', RANK)
|
self.device = torch.device('cuda', RANK)
|
||||||
LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||||
os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
|
os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
|
||||||
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo',
|
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo',
|
||||||
timeout=timedelta(seconds=3600),
|
timeout=timedelta(seconds=3600),
|
||||||
|
@@ -27,10 +27,13 @@ def generate_ddp_file(trainer):
|
|||||||
"""Generates a DDP file and returns its file name."""
|
"""Generates a DDP file and returns its file name."""
|
||||||
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
|
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
|
||||||
|
|
||||||
content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
|
content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
|
||||||
from {module} import {name}
|
from {module} import {name}
|
||||||
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT
|
||||||
|
|
||||||
trainer = {name}(cfg=cfg)
|
cfg = DEFAULT_CFG_DICT.copy()
|
||||||
|
cfg.update(save_dir='') # handle the extra key 'save_dir'
|
||||||
|
trainer = {name}(cfg=cfg, overrides=overrides)
|
||||||
trainer.train()'''
|
trainer.train()'''
|
||||||
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
|
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
|
||||||
with tempfile.NamedTemporaryFile(prefix='_temp_',
|
with tempfile.NamedTemporaryFile(prefix='_temp_',
|
||||||
@@ -54,9 +57,7 @@ def generate_ddp_command(world_size, trainer):
|
|||||||
file = generate_ddp_file(trainer)
|
file = generate_ddp_file(trainer)
|
||||||
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
|
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
|
||||||
port = find_free_network_port()
|
port = find_free_network_port()
|
||||||
exclude_args = ['save_dir']
|
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
|
||||||
args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
|
|
||||||
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
|
|
||||||
return cmd, file
|
return cmd, file
|
||||||
|
|
||||||
|
|
||||||
|
@@ -19,7 +19,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
overrides['task'] = 'classify'
|
overrides['task'] = 'classify'
|
||||||
if overrides.get('imgsz') is None and cfg['imgsz'] == DEFAULT_CFG.imgsz == 640:
|
if overrides.get('imgsz') is None:
|
||||||
overrides['imgsz'] = 224
|
overrides['imgsz'] = 224
|
||||||
super().__init__(cfg, overrides, _callbacks)
|
super().__init__(cfg, overrides, _callbacks)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user