Fix device counting method to account for double-digit device IDs (#8502)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Omar Duhaiby 2024-03-01 12:19:06 +01:00 committed by GitHub
parent 59ed47c448
commit 3c6b9b6688
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -115,7 +115,7 @@ def select_device(device="", batch=0, newline=False, verbose=True):
device = "0"
visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", ""))):
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
LOGGER.info(s)
install = (
"See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "