diff --git a/tests/test_cuda.py b/tests/test_cuda.py
index 3585715b..3f9d5332 100644
--- a/tests/test_cuda.py
+++ b/tests/test_cuda.py
@@ -28,7 +28,14 @@ def test_checks():
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_train():
     device = 0 if CUDA_DEVICE_COUNT == 1 else [0, 1]
-    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=device)  # also test AutoBatch, requires imgsz>=64
+    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device)  # requires imgsz>=64
+
+
+@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
+def test_autobatch():
+    from ultralytics.utils.autobatch import check_train_batch_size
+
+    check_train_batch_size(YOLO(MODEL).model.cuda(), imgsz=128, amp=True)
 
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 84167df1..d06a2e42 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -164,7 +164,7 @@ class BaseTrainer:
         """Allow device='', device=None on Multi-GPU systems to default to device=0."""
         if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
             world_size = len(self.args.device.split(','))
-        elif isinstance(self.args.device, tuple):  # multi devices from cli is tuple type
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
             world_size = len(self.args.device)
         elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
             world_size = 1  # default to device 0