mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 05:55:51 +08:00
CLI DDP fixes (#135)
This commit is contained in:
parent
8f3cd52844
commit
c5c86a3acd
@ -5,7 +5,7 @@ import time
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ultralytics.hub.config import HUB_API_ROOT
|
from ultralytics.hub.config import HUB_API_ROOT
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, colorstr, emojis
|
||||||
|
|
||||||
PREFIX = colorstr('Ultralytics: ')
|
PREFIX = colorstr('Ultralytics: ')
|
||||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
||||||
@ -49,7 +49,7 @@ def request_with_credentials(url: str) -> any:
|
|||||||
|
|
||||||
|
|
||||||
# Deprecated TODO: eliminate this function?
|
# Deprecated TODO: eliminate this function?
|
||||||
def split_key(key: str = '') -> tuple[str, str]:
|
def split_key(key=''):
|
||||||
"""
|
"""
|
||||||
Verify and split a 'api_key[sep]model_id' string, sep is one of '.' or '_'
|
Verify and split a 'api_key[sep]model_id' string, sep is one of '.' or '_'
|
||||||
|
|
||||||
|
@ -29,10 +29,11 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def torch_distributed_zero_first(local_rank: int):
|
def torch_distributed_zero_first(local_rank: int):
|
||||||
# Decorator to make all processes in distributed training wait for each local_master to do something
|
# Decorator to make all processes in distributed training wait for each local_master to do something
|
||||||
if local_rank not in {-1, 0}:
|
initialized = torch.distributed.is_initialized() # prevent 'Default process group has not been initialized' errors
|
||||||
|
if initialized and local_rank not in {-1, 0}:
|
||||||
dist.barrier(device_ids=[local_rank])
|
dist.barrier(device_ids=[local_rank])
|
||||||
yield
|
yield
|
||||||
if local_rank == 0:
|
if initialized and local_rank == 0:
|
||||||
dist.barrier(device_ids=[0])
|
dist.barrier(device_ids=[0])
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user