mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-10-25 02:05:38 +08:00
Fix DDP when device is a list (#4600)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
23b4f697c9
commit
53b4f8c713
@ -28,7 +28,14 @@ def test_checks():
|
|||||||
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
||||||
def test_train():
|
def test_train():
|
||||||
device = 0 if CUDA_DEVICE_COUNT == 1 else [0, 1]
|
device = 0 if CUDA_DEVICE_COUNT == 1 else [0, 1]
|
||||||
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=device) # also test AutoBatch, requires imgsz>=64
|
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device) # requires imgsz>=64
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
||||||
|
def test_autobatch():
|
||||||
|
from ultralytics.utils.autobatch import check_train_batch_size
|
||||||
|
|
||||||
|
check_train_batch_size(YOLO(MODEL).model.cuda(), imgsz=128, amp=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
||||||
|
@ -164,7 +164,7 @@ class BaseTrainer:
|
|||||||
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
||||||
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
|
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
|
||||||
world_size = len(self.args.device.split(','))
|
world_size = len(self.args.device.split(','))
|
||||||
elif isinstance(self.args.device, tuple): # multi devices from cli is tuple type
|
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
||||||
world_size = len(self.args.device)
|
world_size = len(self.args.device)
|
||||||
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
||||||
world_size = 1 # default to device 0
|
world_size = 1 # default to device 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user