mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-31 06:15:39 +08:00 
			
		
		
		
	ultralytics 8.0.19 seg/det dataset warning and DDP-cls/seg fixes (#595)
				
					
				
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 曾逸夫(Zeng Yifu) <41098760+Zengyf-CVer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									936414c615
								
							
						
					
					
						commit
						520825c4b2
					
				| @ -14,7 +14,7 @@ ci: | ||||
| 
 | ||||
| repos: | ||||
|   - repo: https://github.com/pre-commit/pre-commit-hooks | ||||
|     rev: v4.3.0 | ||||
|     rev: v4.4.0 | ||||
|     hooks: | ||||
|       # - id: end-of-file-fixer | ||||
|       - id: trailing-whitespace | ||||
| @ -25,14 +25,14 @@ repos: | ||||
|       - id: check-docstring-first | ||||
| 
 | ||||
|   - repo: https://github.com/asottile/pyupgrade | ||||
|     rev: v2.37.3 | ||||
|     rev: v3.3.1 | ||||
|     hooks: | ||||
|       - id: pyupgrade | ||||
|         name: Upgrade code | ||||
|         args: [ --py37-plus ] | ||||
| 
 | ||||
|   - repo: https://github.com/PyCQA/isort | ||||
|     rev: 5.10.1 | ||||
|     rev: 5.11.4 | ||||
|     hooks: | ||||
|       - id: isort | ||||
|         name: Sort imports | ||||
| @ -59,6 +59,13 @@ repos: | ||||
|       - id: flake8 | ||||
|         name: PEP8 | ||||
| 
 | ||||
|   - repo: https://github.com/codespell-project/codespell | ||||
|     rev: v2.2.2 | ||||
|     hooks: | ||||
|       - id: codespell | ||||
|         args: | ||||
|           - --ignore-words-list=crate,nd | ||||
| 
 | ||||
|   #- repo: https://github.com/asottile/yesqa | ||||
|   #  rev: v1.4.0 | ||||
|   #  hooks: | ||||
|  | ||||
| @ -183,7 +183,7 @@ Default arguments can be overriden by simply passing them as arguments in the CL | ||||
| You can override the `default.yaml` config file entirely by passing a new file with the `cfg` arguments, | ||||
| i.e. `cfg=custom.yaml`. | ||||
| 
 | ||||
| To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-config` command. | ||||
| To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-cfg` command. | ||||
| 
 | ||||
| This will create `default_copy.yaml`, which you can then pass as `cfg=default_copy.yaml` along with any additional args, | ||||
| like `imgsz=320` in this example: | ||||
| @ -192,6 +192,6 @@ like `imgsz=320` in this example: | ||||
| 
 | ||||
|     === "CLI" | ||||
|         ```bash | ||||
|         yolo copy-config | ||||
|         yolo copy-cfg | ||||
|         yolo cfg=default_copy.yaml imgsz=320 | ||||
|         ``` | ||||
| @ -638,11 +638,11 @@ | ||||
|     { | ||||
|       "cell_type": "code", | ||||
|       "source": [ | ||||
|         "# Load YOLOv8n-cls, train it on imagenette160 for 3 epochs and predict an image with it\n", | ||||
|         "# Load YOLOv8n-cls, train it on mnist160 for 3 epochs and predict an image with it\n", | ||||
|         "from ultralytics import YOLO\n", | ||||
|         "\n", | ||||
|         "model = YOLO('yolov8n-cls.pt')  # load a pretrained YOLOv8n classification model\n", | ||||
|         "model.train(data='imagenette160', epochs=3)  # train the model\n", | ||||
|         "model.train(data='mnist160', epochs=3)  # train the model\n", | ||||
|         "model('https://ultralytics.com/images/bus.jpg')  # predict on an image" | ||||
|       ], | ||||
|       "metadata": { | ||||
|  | ||||
| @ -3,13 +3,13 @@ | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG_PATH, ROOT, SETTINGS | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, SETTINGS | ||||
| from ultralytics.yolo.v8 import classify, detect, segment | ||||
| 
 | ||||
| CFG_DET = 'yolov8n.yaml' | ||||
| CFG_SEG = 'yolov8n-seg.yaml' | ||||
| CFG_CLS = 'squeezenet1_0' | ||||
| CFG = get_cfg(DEFAULT_CFG_PATH) | ||||
| CFG = get_cfg(DEFAULT_CFG) | ||||
| MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n' | ||||
| SOURCE = ROOT / "assets" | ||||
| 
 | ||||
|  | ||||
| @ -313,13 +313,39 @@ class ClassificationModel(BaseModel): | ||||
| # Functions ------------------------------------------------------------------------------------------------------------ | ||||
| 
 | ||||
| 
 | ||||
| def torch_safe_load(weight): | ||||
|     """ | ||||
|     This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it | ||||
|     catches the error, logs a warning message, and attempts to install the missing module via the check_requirements() | ||||
|     function. After installation, the function again attempts to load the model using torch.load(). | ||||
| 
 | ||||
|     Args: | ||||
|         weight (str): The file path of the PyTorch model. | ||||
| 
 | ||||
|     Returns: | ||||
|         The loaded PyTorch model. | ||||
|     """ | ||||
|     from ultralytics.yolo.utils.downloads import attempt_download | ||||
| 
 | ||||
|     file = attempt_download(weight)  # search online if missing locally | ||||
|     try: | ||||
|         return torch.load(file, map_location='cpu')  # load | ||||
|     except ModuleNotFoundError as e: | ||||
|         if e.name == 'omegaconf':  # e.name is missing module name | ||||
|             LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements." | ||||
|                            f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future." | ||||
|                            f"\nRecommend fixes are to train a new model using updated ultraltyics package or to " | ||||
|                            f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0") | ||||
|         check_requirements(e.name)  # install missing module | ||||
|         return torch.load(file, map_location='cpu')  # load | ||||
| 
 | ||||
| 
 | ||||
| def attempt_load_weights(weights, device=None, inplace=True, fuse=False): | ||||
|     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a | ||||
|     from ultralytics.yolo.utils.downloads import attempt_download | ||||
| 
 | ||||
|     model = Ensemble() | ||||
|     for w in weights if isinstance(weights, list) else [weights]: | ||||
|         ckpt = torch.load(attempt_download(w), map_location='cpu')  # load | ||||
|         ckpt = torch_safe_load(w)  # load ckpt | ||||
|         args = {**DEFAULT_CFG_DICT, **ckpt['train_args']}  # combine model and default args, preferring model args | ||||
|         ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model | ||||
| 
 | ||||
| @ -355,18 +381,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): | ||||
| 
 | ||||
| def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): | ||||
|     # Loads a single model weights | ||||
|     from ultralytics.yolo.utils.downloads import attempt_download | ||||
| 
 | ||||
|     weight = attempt_download(weight) | ||||
|     try: | ||||
|         ckpt = torch.load(weight, map_location='cpu')  # load | ||||
|     except ModuleNotFoundError: | ||||
|         LOGGER.warning(f"WARNING ⚠️ {weight} is deprecated as it requires omegaconf, which is now removed from " | ||||
|                        "ultralytics requirements.\nAutoInstall will occur now but this feature will be removed for " | ||||
|                        "omegaconf models in the future.\nPlease train a new model or download updated models " | ||||
|                        "from https://github.com/ultralytics/assets/releases/tag/v0.0.0") | ||||
|         check_requirements('omegaconf') | ||||
|         ckpt = torch.load(weight, map_location='cpu')  # load | ||||
|     ckpt = torch_safe_load(weight)  # load ckpt | ||||
|     args = {**DEFAULT_CFG_DICT, **ckpt['train_args']}  # combine model and default args, preferring model args | ||||
|     model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model | ||||
| 
 | ||||
|  | ||||
| @ -611,6 +611,8 @@ class LoadImagesAndLabels(Dataset): | ||||
| 
 | ||||
|     def cache_labels(self, path=Path('./labels.cache'), prefix=''): | ||||
|         # Cache dataset labels, check images and read shapes | ||||
|         if path.exists(): | ||||
|             path.unlink()  # remove *.cache file if exists | ||||
|         x = {}  # dict | ||||
|         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages | ||||
|         desc = f"{prefix}Scanning {path.parent / path.stem}..." | ||||
|  | ||||
| @ -47,6 +47,8 @@ class YOLODataset(BaseDataset): | ||||
| 
 | ||||
|     def cache_labels(self, path=Path("./labels.cache")): | ||||
|         # Cache dataset labels, check images and read shapes | ||||
|         if path.exists(): | ||||
|             path.unlink()  # remove *.cache file if exists | ||||
|         x = {"labels": []} | ||||
|         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages | ||||
|         desc = f"{self.prefix}Scanning {path.parent / path.stem}..." | ||||
| @ -85,7 +87,7 @@ class YOLODataset(BaseDataset): | ||||
|         x["results"] = nf, nm, ne, nc, len(self.im_files) | ||||
|         x["msgs"] = msgs  # warnings | ||||
|         x["version"] = self.cache_version  # cache version | ||||
|         self.im_files = [lb["im_file"] for lb in x["labels"]] | ||||
|         self.im_files = [lb["im_file"] for lb in x["labels"]]  # update im_files | ||||
|         if is_dir_writeable(path.parent): | ||||
|             np.save(str(path), x)  # save cache for next time | ||||
|             path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix | ||||
| @ -116,6 +118,17 @@ class YOLODataset(BaseDataset): | ||||
|         # Read cache | ||||
|         [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items | ||||
|         labels = cache["labels"] | ||||
| 
 | ||||
|         # Check if the dataset is all boxes or all segments | ||||
|         len_boxes = sum(len(lb["bboxes"]) for lb in labels) | ||||
|         len_segments = sum(len(lb["segments"]) for lb in labels) | ||||
|         if len_segments and len_boxes != len_segments: | ||||
|             LOGGER.warning( | ||||
|                 f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, " | ||||
|                 f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " | ||||
|                 "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.") | ||||
|             for lb in labels: | ||||
|                 lb["segments"] = [] | ||||
|         nl = len(np.concatenate([label["cls"] for label in labels], 0))  # number of labels | ||||
|         assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}" | ||||
|         return labels | ||||
|  | ||||
| @ -14,7 +14,7 @@ import numpy as np | ||||
| import torch | ||||
| from PIL import ExifTags, Image, ImageOps | ||||
| 
 | ||||
| from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, yaml_load | ||||
| from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load | ||||
| from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii | ||||
| from ultralytics.yolo.utils.downloads import download | ||||
| from ultralytics.yolo.utils.files import unzip_file | ||||
| @ -202,7 +202,10 @@ def check_det_dataset(dataset, autodownload=True): | ||||
| 
 | ||||
|     # Checks | ||||
|     for k in 'train', 'val', 'names': | ||||
|         assert k in data, f"data.yaml '{k}:' field missing ❌" | ||||
|         if k not in data: | ||||
|             raise SyntaxError( | ||||
|                 emojis(f"{dataset} '{k}:' key missing ❌.\n" | ||||
|                        f"'train', 'val' and 'names' are required in data.yaml files.")) | ||||
|     if isinstance(data['names'], (list, tuple)):  # old array format | ||||
|         data['names'] = dict(enumerate(data['names']))  # convert to dict | ||||
|     data['nc'] = len(data['names']) | ||||
|  | ||||
| @ -388,7 +388,7 @@ class Exporter: | ||||
|     @try_export | ||||
|     def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): | ||||
|         # YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt | ||||
|         assert self.im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `device==0`' | ||||
|         assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'" | ||||
|         try: | ||||
|             import tensorrt as trt  # noqa | ||||
|         except ImportError: | ||||
|  | ||||
| @ -53,7 +53,12 @@ class YOLO: | ||||
|         self.overrides = {}  # overrides for trainer object | ||||
| 
 | ||||
|         # Load or create new YOLO model | ||||
|         {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model) | ||||
|         load_methods = {'.pt': self._load, '.yaml': self._new} | ||||
|         suffix = Path(model).suffix | ||||
|         if suffix in load_methods: | ||||
|             {'.pt': self._load, '.yaml': self._new}[suffix](model) | ||||
|         else: | ||||
|             raise NotImplementedError(f"'{suffix}' model loading not implemented") | ||||
| 
 | ||||
|     def __call__(self, source=None, stream=False, verbose=False, **kwargs): | ||||
|         return self.predict(source, stream, verbose, **kwargs) | ||||
|  | ||||
| @ -35,7 +35,7 @@ from ultralytics.nn.autobackend import AutoBackend | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams | ||||
| from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops | ||||
| from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow | ||||
| from ultralytics.yolo.utils.files import increment_path | ||||
| from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode | ||||
| @ -61,12 +61,12 @@ class BasePredictor: | ||||
|         data_path (str): Path to data. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None): | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None): | ||||
|         """ | ||||
|         Initializes the BasePredictor class. | ||||
| 
 | ||||
|         Args: | ||||
|             cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. | ||||
|             cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. | ||||
|             overrides (dict, optional): Configuration overrides. Defaults to None. | ||||
|         """ | ||||
|         self.args = get_cfg(cfg, overrides) | ||||
|  | ||||
| @ -24,8 +24,8 @@ from ultralytics import __version__ | ||||
| from ultralytics.nn.tasks import attempt_load_one_weight | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset | ||||
| from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, | ||||
|                                     emojis, yaml_save) | ||||
| from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis, | ||||
|                                     yaml_save) | ||||
| from ultralytics.yolo.utils.autobatch import check_train_batch_size | ||||
| from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args | ||||
| from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command | ||||
| @ -71,12 +71,12 @@ class BaseTrainer: | ||||
|         csv (Path): Path to results CSV file. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None): | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None): | ||||
|         """ | ||||
|         Initializes the BaseTrainer class. | ||||
| 
 | ||||
|         Args: | ||||
|             cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. | ||||
|             cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. | ||||
|             overrides (dict, optional): Configuration overrides. Defaults to None. | ||||
|         """ | ||||
|         self.args = get_cfg(cfg, overrides) | ||||
|  | ||||
| @ -10,7 +10,7 @@ from tqdm import tqdm | ||||
| from ultralytics.nn.autobackend import AutoBackend | ||||
| from ultralytics.yolo.cfg import get_cfg | ||||
| from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis | ||||
| from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis | ||||
| from ultralytics.yolo.utils.checks import check_imgsz | ||||
| from ultralytics.yolo.utils.files import increment_path | ||||
| from ultralytics.yolo.utils.ops import Profile | ||||
| @ -52,7 +52,7 @@ class BaseValidator: | ||||
|         self.dataloader = dataloader | ||||
|         self.pbar = pbar | ||||
|         self.logger = logger or LOGGER | ||||
|         self.args = args or get_cfg(DEFAULT_CFG_PATH) | ||||
|         self.args = args or get_cfg(DEFAULT_CFG) | ||||
|         self.model = None | ||||
|         self.data = None | ||||
|         self.device = None | ||||
|  | ||||
| @ -127,8 +127,7 @@ def is_colab(): | ||||
|     Returns: | ||||
|         bool: True if running inside a Colab notebook, False otherwise. | ||||
|     """ | ||||
|     # Check if the 'google.colab' module is present in sys.modules | ||||
|     return 'google.colab' in sys.modules | ||||
|     return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ | ||||
| 
 | ||||
| 
 | ||||
| def is_kaggle(): | ||||
|  | ||||
| @ -224,7 +224,7 @@ def check_file(file, suffix=''): | ||||
|         for d in 'models', 'yolo/data':  # search directories | ||||
|             files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file | ||||
|         if not files: | ||||
|             raise FileNotFoundError(f"{file} does not exist") | ||||
|             raise FileNotFoundError(f"'{file}' does not exist") | ||||
|         elif len(files) > 1: | ||||
|             raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") | ||||
|         return files[0]  # return file | ||||
|  | ||||
| @ -10,17 +10,14 @@ from . import USER_CONFIG_DIR | ||||
| 
 | ||||
| 
 | ||||
| def find_free_network_port() -> int: | ||||
|     # https://github.com/Lightning-AI/lightning/blob/master/src/lightning_lite/plugins/environments/lightning.py | ||||
|     """Finds a free port on localhost. | ||||
| 
 | ||||
|     It is useful in single-node training when we don't want to connect to a real main node but have to set the | ||||
|     `MASTER_PORT` environment variable. | ||||
|     """ | ||||
|     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||||
|     s.bind(("", 0)) | ||||
|     port = s.getsockname()[1] | ||||
|     s.close() | ||||
|     return port | ||||
|     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | ||||
|         s.bind(('127.0.0.1', 0)) | ||||
|         return s.getsockname()[1]  # port | ||||
| 
 | ||||
| 
 | ||||
| def generate_ddp_file(trainer): | ||||
|  | ||||
| @ -91,12 +91,10 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'): | ||||
| 
 | ||||
|         file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required) | ||||
|         if name in assets: | ||||
|             url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl'  # backup gdrive mirror | ||||
|             safe_download( | ||||
|                 file, | ||||
|             safe_download(file, | ||||
|                           url=f'https://github.com/{repo}/releases/download/{tag}/{name}', | ||||
|                           min_bytes=1E5, | ||||
|                 error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}') | ||||
|                           error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}') | ||||
| 
 | ||||
|         return str(file) | ||||
| 
 | ||||
|  | ||||
| @ -58,7 +58,7 @@ def DDP_model(model): | ||||
|         return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) | ||||
| 
 | ||||
| 
 | ||||
| def select_device(device='', batch_size=0, newline=False): | ||||
| def select_device(device='', batch=0, newline=False): | ||||
|     # device = None or 'cpu' or 0 or '0' or '0,1,2,3' | ||||
|     ver = git_describe() or ultralytics.__version__  # git commit or pip package version | ||||
|     s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} ' | ||||
| @ -71,14 +71,15 @@ def select_device(device='', batch_size=0, newline=False): | ||||
|         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False | ||||
|     elif device:  # non-cpu device requested | ||||
|         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available() | ||||
|         assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \ | ||||
|             f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)" | ||||
|         if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))): | ||||
|             raise ValueError(f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)") | ||||
| 
 | ||||
|     if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available | ||||
|         devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7 | ||||
|         n = len(devices)  # device count | ||||
|         if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count | ||||
|             assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}' | ||||
|         if n > 1 and batch > 0 and batch % n != 0:  # check batch_size is divisible by device_count | ||||
|             raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n' | ||||
|                              f'Try batch={batch // n} or batch={batch // n + 1}') | ||||
|         space = ' ' * (len(s) + 1) | ||||
|         for i, d in enumerate(devices): | ||||
|             p = torch.cuda.get_device_properties(i) | ||||
|  | ||||
| @ -13,11 +13,11 @@ from ultralytics.yolo.utils.torch_utils import strip_optimizer | ||||
| 
 | ||||
| class ClassificationTrainer(BaseTrainer): | ||||
| 
 | ||||
|     def __init__(self, config=DEFAULT_CFG, overrides=None): | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None): | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         overrides["task"] = "classify" | ||||
|         super().__init__(config, overrides) | ||||
|         super().__init__(cfg, overrides) | ||||
| 
 | ||||
|     def set_model_attributes(self): | ||||
|         self.model.names = self.data["names"] | ||||
|  | ||||
| @ -47,7 +47,7 @@ class ClassificationValidator(BaseValidator): | ||||
| 
 | ||||
| def val(cfg=DEFAULT_CFG): | ||||
|     cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18" | ||||
|     cfg.data = cfg.data or "imagenette160" | ||||
|     cfg.data = cfg.data or "mnist160" | ||||
|     validator = ClassificationValidator(args=cfg) | ||||
|     validator(model=cfg.model) | ||||
| 
 | ||||
|  | ||||
| @ -18,11 +18,11 @@ from ultralytics.yolo.v8.detect.train import Loss | ||||
| # BaseTrainer python usage | ||||
| class SegmentationTrainer(v8.detect.DetectionTrainer): | ||||
| 
 | ||||
|     def __init__(self, config=DEFAULT_CFG, overrides=None): | ||||
|     def __init__(self, cfg=DEFAULT_CFG, overrides=None): | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         overrides["task"] = "segment" | ||||
|         super().__init__(config, overrides) | ||||
|         super().__init__(cfg, overrides) | ||||
| 
 | ||||
|     def get_model(self, cfg=None, weights=None, verbose=True): | ||||
|         model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher