mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-25 18:35:40 +08:00 
			
		
		
		
	General ultralytics==8.0.6 updates (#351)
				
					
				
			Co-authored-by: Dzmitry Plashchynski <plashchynski@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									70427579b8
								
							
						
					
					
						commit
						f8e32c4c13
					
				
							
								
								
									
										12
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							| @ -84,22 +84,22 @@ jobs: | ||||
|       - name: Test detection | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           yolo task=detect mode=train model=yolov8n.yaml data=coco8.yaml epochs=1 imgsz=32 | ||||
|           yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=32 | ||||
|           yolo task=detect mode=train data=coco8.yaml model=yolov8n.yaml epochs=1 imgsz=32 | ||||
|           yolo task=detect mode=val data=coco8.yaml model=runs/detect/train/weights/last.pt imgsz=32 | ||||
|           yolo task=detect mode=predict model=runs/detect/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg | ||||
|           yolo mode=export model=runs/detect/train/weights/last.pt imgsz=32 format=torchscript | ||||
|       - name: Test segmentation | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           yolo task=segment mode=train model=yolov8n-seg.yaml data=coco8-seg.yaml epochs=1 imgsz=32 | ||||
|           yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco8-seg.yaml imgsz=32 | ||||
|           yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.yaml epochs=1 imgsz=32 | ||||
|           yolo task=segment mode=val data=coco8-seg.yaml model=runs/segment/train/weights/last.pt imgsz=32 | ||||
|           yolo task=segment mode=predict model=runs/segment/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg | ||||
|           yolo mode=export model=runs/segment/train/weights/last.pt imgsz=32 format=torchscript | ||||
|       - name: Test classification | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           yolo task=classify mode=train model=yolov8n-cls.yaml data=mnist160 epochs=1 imgsz=32 | ||||
|           yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 imgsz=32 | ||||
|           yolo task=classify mode=train data=mnist160 model=yolov8n-cls.yaml epochs=1 imgsz=32 | ||||
|           yolo task=classify mode=val data=mnist160 model=runs/classify/train/weights/last.pt imgsz=32 | ||||
|           yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg | ||||
|           yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript | ||||
|       - name: Pytest tests | ||||
|  | ||||
| @ -52,7 +52,7 @@ ENV OMP_NUM_THREADS=1 | ||||
| # t=ultralytics/ultralytics:latest tnew=ultralytics/ultralytics:v6.2 && sudo docker pull $t && sudo docker tag $t $tnew && sudo docker push $tnew | ||||
| 
 | ||||
| # Clean up | ||||
| # docker system prune -a --volumes | ||||
| # sudo docker system prune -a --volumes | ||||
| 
 | ||||
| # Update Ubuntu drivers | ||||
| # https://www.maketecheasier.com/install-nvidia-drivers-ubuntu/ | ||||
|  | ||||
| @ -1,13 +1,16 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from ultralytics.yolo.configs import get_config | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, SETTINGS | ||||
| from ultralytics.yolo.v8 import classify, detect, segment | ||||
| 
 | ||||
| CFG_DET = 'yolov8n.yaml' | ||||
| CFG_SEG = 'yolov8n-seg.yaml' | ||||
| CFG_CLS = 'squeezenet1_0' | ||||
| CFG = get_config(DEFAULT_CONFIG) | ||||
| MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n' | ||||
| SOURCE = ROOT / "assets" | ||||
| 
 | ||||
| 
 | ||||
| @ -18,15 +21,14 @@ def test_detect(): | ||||
|     # Trainer | ||||
|     trainer = detect.DetectionTrainer(overrides=overrides) | ||||
|     trainer.train() | ||||
|     trained_model = trainer.best | ||||
| 
 | ||||
|     # Validator | ||||
|     val = detect.DetectionValidator(args=CFG) | ||||
|     val(model=trained_model) | ||||
|     val(model=trainer.best)  # validate best.pt | ||||
| 
 | ||||
|     # Predictor | ||||
|     pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]}) | ||||
|     result = pred(source=SOURCE, model="yolov8n.pt", return_outputs=True) | ||||
|     result = pred(source=SOURCE, model=f"{MODEL}.pt", return_outputs=True) | ||||
|     assert len(list(result)), "predictor test failed" | ||||
| 
 | ||||
|     overrides["resume"] = trainer.last | ||||
| @ -49,15 +51,14 @@ def test_segment(): | ||||
|     # trainer | ||||
|     trainer = segment.SegmentationTrainer(overrides=overrides) | ||||
|     trainer.train() | ||||
|     trained_model = trainer.best | ||||
| 
 | ||||
|     # Validator | ||||
|     val = segment.SegmentationValidator(args=CFG) | ||||
|     val(model=trained_model) | ||||
|     val(model=trainer.best)  # validate best.pt | ||||
| 
 | ||||
|     # Predictor | ||||
|     pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]}) | ||||
|     result = pred(source=SOURCE, model="yolov8n-seg.pt", return_outputs=True) | ||||
|     result = pred(source=SOURCE, model=f"{MODEL}-seg.pt", return_outputs=True) | ||||
|     assert len(list(result)) == 2, "predictor test failed" | ||||
| 
 | ||||
|     # Test resume | ||||
| @ -82,13 +83,12 @@ def test_classify(): | ||||
|     # Trainer | ||||
|     trainer = classify.ClassificationTrainer(overrides=overrides) | ||||
|     trainer.train() | ||||
|     trained_model = trainer.best | ||||
| 
 | ||||
|     # Validator | ||||
|     val = classify.ClassificationValidator(args=CFG) | ||||
|     val(model=trained_model) | ||||
|     val(model=trainer.best) | ||||
| 
 | ||||
|     # Predictor | ||||
|     pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]}) | ||||
|     result = pred(source=SOURCE, model=trained_model, return_outputs=True) | ||||
|     result = pred(source=SOURCE, model=trainer.best, return_outputs=True) | ||||
|     assert len(list(result)) == 2, "predictor test failed" | ||||
|  | ||||
| @ -1,7 +1,5 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| import signal | ||||
| import sys | ||||
| from pathlib import Path | ||||
| from time import sleep | ||||
| 
 | ||||
| @ -15,19 +13,21 @@ AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__versio | ||||
| 
 | ||||
| session = None | ||||
| 
 | ||||
| 
 | ||||
| def signal_handler(signum, frame): | ||||
|     """ Confirm exit """ | ||||
|     global hub_logger | ||||
|     LOGGER.info(f'Signal received. {signum} {frame}') | ||||
|     if isinstance(session, HubTrainingSession): | ||||
|         hub_logger.alive = False | ||||
|         del hub_logger | ||||
|     sys.exit(signum) | ||||
| 
 | ||||
| 
 | ||||
| signal.signal(signal.SIGTERM, signal_handler) | ||||
| signal.signal(signal.SIGINT, signal_handler) | ||||
| # Causing problems in tests (non-authenticated) | ||||
| # import signal | ||||
| # import sys | ||||
| # def signal_handler(signum, frame): | ||||
| #     """ Confirm exit """ | ||||
| #     global hub_logger | ||||
| #     LOGGER.info(f'Signal received. {signum} {frame}') | ||||
| #     if isinstance(session, HubTrainingSession): | ||||
| #         hub_logger.alive = False | ||||
| #         del hub_logger | ||||
| #     sys.exit(signum) | ||||
| # | ||||
| # | ||||
| # signal.signal(signal.SIGTERM, signal_handler) | ||||
| # signal.signal(signal.SIGINT, signal_handler) | ||||
| 
 | ||||
| 
 | ||||
| class HubTrainingSession: | ||||
|  | ||||
| @ -8,13 +8,13 @@ from omegaconf import DictConfig, OmegaConf | ||||
| from ultralytics.yolo.configs.hydra_patch import check_config_mismatch | ||||
| 
 | ||||
| 
 | ||||
| def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = None): | ||||
| def get_config(config: Union[str, Path, DictConfig], overrides: Union[str, Dict] = None): | ||||
|     """ | ||||
|     Load and merge configuration data from a file or dictionary. | ||||
| 
 | ||||
|     Args: | ||||
|         config (Union[str, DictConfig]): Configuration data in the form of a file name or a DictConfig object. | ||||
|         overrides (Union[str, Dict], optional): Overrides in the form of a file name or a dictionary. Default is None. | ||||
|         config (str) or (Path) or (DictConfig): Configuration data in the form of a file name or a DictConfig object. | ||||
|         overrides (str) or(Dict), optional: Overrides in the form of a file name or a dictionary. Default is None. | ||||
| 
 | ||||
|     Returns: | ||||
|         OmegaConf.Namespace: Training arguments namespace. | ||||
|  | ||||
| @ -14,12 +14,11 @@ import numpy as np | ||||
| import torch | ||||
| from PIL import ExifTags, Image, ImageOps | ||||
| 
 | ||||
| from ultralytics.yolo.utils import LOGGER, ROOT, colorstr, yaml_load | ||||
| from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, yaml_load | ||||
| from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii | ||||
| from ultralytics.yolo.utils.downloads import download | ||||
| from ultralytics.yolo.utils.files import unzip_file | ||||
| 
 | ||||
| from ..utils.ops import segments2boxes | ||||
| from ultralytics.yolo.utils.ops import segments2boxes | ||||
| 
 | ||||
| HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data" | ||||
| IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes | ||||
| @ -173,12 +172,7 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): | ||||
|     areas = [] | ||||
|     ms = [] | ||||
|     for si in range(len(segments)): | ||||
|         mask = polygon2mask( | ||||
|             imgsz, | ||||
|             [segments[si].reshape(-1)], | ||||
|             downsample_ratio=downsample_ratio, | ||||
|             color=1, | ||||
|         ) | ||||
|         mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1) | ||||
|         ms.append(mask) | ||||
|         areas.append(mask.sum()) | ||||
|     areas = np.asarray(areas) | ||||
| @ -194,13 +188,14 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): | ||||
| def check_dataset_yaml(data, autodownload=True): | ||||
|     # Download, check and/or unzip dataset if not found locally | ||||
|     data = check_file(data) | ||||
|     DATASETS_DIR = (Path.cwd() / "../datasets").resolve()  # TODO: handle global dataset dir | ||||
| 
 | ||||
|     # Download (optional) | ||||
|     extract_dir = '' | ||||
|     if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)): | ||||
|         download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1) | ||||
|         data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml')) | ||||
|         extract_dir, autodownload = data.parent, False | ||||
| 
 | ||||
|     # Read yaml (optional) | ||||
|     if isinstance(data, (str, Path)): | ||||
|         data = yaml_load(data, append_filename=True)  # dictionary | ||||
| @ -215,7 +210,7 @@ def check_dataset_yaml(data, autodownload=True): | ||||
|     # Resolve paths | ||||
|     path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.' | ||||
|     if not path.is_absolute(): | ||||
|         path = (Path.cwd() / path).resolve() | ||||
|         path = (DATASETS_DIR / path).resolve() | ||||
|         data['path'] = path  # download scripts | ||||
|     for k in 'train', 'val', 'test': | ||||
|         if data.get(k):  # prepend path | ||||
| @ -253,6 +248,7 @@ def check_dataset_yaml(data, autodownload=True): | ||||
|             s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" | ||||
|             LOGGER.info(f"Dataset download {s}") | ||||
|     check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True)  # download fonts | ||||
| 
 | ||||
|     return data  # dictionary | ||||
| 
 | ||||
| 
 | ||||
| @ -274,12 +270,12 @@ def check_dataset(dataset: str): | ||||
|             'nc': Number of classes in the dataset | ||||
|             'names': List of class names in the dataset | ||||
|     """ | ||||
|     data_dir = (Path.cwd() / "datasets" / dataset).resolve() | ||||
|     data_dir = (DATASETS_DIR / dataset).resolve() | ||||
|     if not data_dir.is_dir(): | ||||
|         LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') | ||||
|         t = time.time() | ||||
|         if dataset == 'imagenet': | ||||
|             subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) | ||||
|             subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True) | ||||
|         else: | ||||
|             url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' | ||||
|             download(url, dir=data_dir.parent) | ||||
|  | ||||
| @ -240,7 +240,7 @@ class BasePredictor: | ||||
|                 if isinstance(self.vid_writer[idx], cv2.VideoWriter): | ||||
|                     self.vid_writer[idx].release()  # release previous video writer | ||||
|                 if vid_cap:  # video | ||||
|                     fps = vid_cap.get(cv2.CAP_PROP_FPS) | ||||
|                     fps = int(vid_cap.get(cv2.CAP_PROP_FPS))  # integer required, floats produce error in MP4 codec | ||||
|                     w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | ||||
|                     h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | ||||
|                 else:  # stream | ||||
|  | ||||
| @ -506,9 +506,11 @@ class BaseTrainer: | ||||
|     def check_resume(self): | ||||
|         resume = self.args.resume | ||||
|         if resume: | ||||
|             last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run()) | ||||
|             last = Path(check_file(resume) if isinstance(resume, (str, Path)) else get_latest_run()) | ||||
|             args_yaml = last.parent.parent / 'args.yaml'  # train options yaml | ||||
|             if args_yaml.is_file(): | ||||
|             assert args_yaml.is_file(), \ | ||||
|                 FileNotFoundError('Resume checkpoint f{last} not found. ' | ||||
|                                   'Please pass a valid checkpoint to resume from, i.e. yolo resume=path/to/last.pt') | ||||
|             args = get_config(args_yaml)  # replace | ||||
|             args.model, resume = str(last), True  # reinstate | ||||
|             self.args = args | ||||
|  | ||||
| @ -187,7 +187,7 @@ def get_git_root_dir(): | ||||
|     """ | ||||
|     try: | ||||
|         output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True) | ||||
|         return Path(output.stdout.strip().decode('utf-8')).parent  # parent/.git | ||||
|         return Path(output.stdout.strip().decode('utf-8')).parent.resolve()  # parent/.git | ||||
|     except subprocess.CalledProcessError: | ||||
|         return None | ||||
| 
 | ||||
| @ -348,16 +348,18 @@ def yaml_load(file='data.yaml', append_filename=False): | ||||
|         return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f) | ||||
| 
 | ||||
| 
 | ||||
| def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): | ||||
| def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.0'): | ||||
|     """ | ||||
|     Loads a global settings YAML file or creates one with default values if it does not exist. | ||||
|     Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist. | ||||
| 
 | ||||
|     Args: | ||||
|         file (Path): Path to the settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR. | ||||
|         file (Path): Path to the Ultralytics settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR. | ||||
|         version (str): Settings version. If min settings version not met, new default settings will be saved. | ||||
| 
 | ||||
|     Returns: | ||||
|         dict: Dictionary of settings key-value pairs. | ||||
|     """ | ||||
|     from ultralytics.yolo.utils.checks import check_version | ||||
|     from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first | ||||
| 
 | ||||
|     root = get_git_root_dir() or Path('')  # not is_pip_package() | ||||
| @ -366,7 +368,8 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): | ||||
|         'weights_dir': str(root / 'weights'),  # default weights directory. | ||||
|         'runs_dir': str(root / 'runs'),  # default runs directory. | ||||
|         'sync': True,  # sync analytics to help with YOLO development | ||||
|         'uuid': uuid.getnode()}  # device UUID to align analytics | ||||
|         'uuid': uuid.getnode(),  # device UUID to align analytics | ||||
|         'settings_version': version}  # Ultralytics settings version | ||||
| 
 | ||||
|     with torch_distributed_zero_first(RANK): | ||||
|         if not file.exists(): | ||||
| @ -375,12 +378,14 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): | ||||
|         settings = yaml_load(file) | ||||
| 
 | ||||
|         # Check that settings keys and types match defaults | ||||
|         correct = settings.keys() == defaults.keys() and \ | ||||
|                   all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) | ||||
|         correct = settings.keys() == defaults.keys() \ | ||||
|                   and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \ | ||||
|                   and check_version(settings['settings_version'], version) | ||||
|         if not correct: | ||||
|             LOGGER.warning('WARNING ⚠️ Different global settings detected, resetting to defaults. ' | ||||
|                            'This may be due to an ultralytics package update. ' | ||||
|                            f'View and update your global settings directly in {file}') | ||||
|             LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. ' | ||||
|                            '\nThis is normal and may be due to a recent ultralytics package update, ' | ||||
|                            'but may have overwritten previous settings. ' | ||||
|                            f"\nYou may view and update settings directly in '{file}'") | ||||
|             settings = defaults  # merge **defaults with **settings (prefer **settings) | ||||
|             yaml_save(file, settings)  # save updated defaults | ||||
| 
 | ||||
|  | ||||
| @ -3,8 +3,6 @@ | ||||
| import json | ||||
| from time import time | ||||
| 
 | ||||
| import torch | ||||
| 
 | ||||
| from ultralytics.hub.utils import PREFIX, sync_analytics | ||||
| from ultralytics.yolo.utils import LOGGER | ||||
| 
 | ||||
|  | ||||
| @ -252,7 +252,7 @@ class ConfusionMatrix: | ||||
|                        vmin=0.0, | ||||
|                        xticklabels=ticklabels, | ||||
|                        yticklabels=ticklabels).set_facecolor((1, 1, 1)) | ||||
|         ax.set_ylabel('True') | ||||
|         ax.set_xlabel('True') | ||||
|         ax.set_ylabel('Predicted') | ||||
|         ax.set_title('Confusion Matrix') | ||||
|         fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) | ||||
|  | ||||
| @ -113,11 +113,10 @@ class ClassificationTrainer(BaseTrainer): | ||||
|         """ | ||||
|         # Not needed for classification but necessary for segmentation & detection | ||||
|         keys = [f"{prefix}/{x}" for x in self.loss_names] | ||||
|         if loss_items is not None: | ||||
|         if loss_items is None: | ||||
|             return keys | ||||
|         loss_items = [round(float(loss_items), 5)] | ||||
|         return dict(zip(keys, loss_items)) | ||||
|         else: | ||||
|             return keys | ||||
| 
 | ||||
|     def resume_training(self, ckpt): | ||||
|         pass | ||||
|  | ||||
| @ -48,14 +48,14 @@ class DetectionTrainer(BaseTrainer): | ||||
|         return batch | ||||
| 
 | ||||
|     def set_model_attributes(self): | ||||
|         nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps) | ||||
|         self.args.box *= 3 / nl  # scale to layers | ||||
|         # nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps) | ||||
|         # self.args.box *= 3 / nl  # scale to layers | ||||
|         # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers | ||||
|         self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers | ||||
|         # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers | ||||
|         self.model.nc = self.data["nc"]  # attach number of classes to model | ||||
|         self.model.names = self.data["names"]  # attach class names to model | ||||
|         self.model.args = self.args  # attach hyperparameters to model | ||||
|         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc | ||||
|         self.model.names = self.data["names"] | ||||
| 
 | ||||
|     def get_model(self, cfg=None, weights=None, verbose=True): | ||||
|         model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose) | ||||
|  | ||||
| @ -6,8 +6,7 @@ import torch | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops | ||||
| from ultralytics.yolo.utils.checks import check_imgsz | ||||
| from ultralytics.yolo.utils.plotting import colors, save_one_box | ||||
| 
 | ||||
| from ..detect.predict import DetectionPredictor | ||||
| from ultralytics.yolo.v8.detect.predict import DetectionPredictor | ||||
| 
 | ||||
| 
 | ||||
| class SegmentationPredictor(DetectionPredictor): | ||||
|  | ||||
| @ -13,14 +13,15 @@ from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh | ||||
| from ultralytics.yolo.utils.plotting import plot_images, plot_results | ||||
| from ultralytics.yolo.utils.tal import make_anchors | ||||
| from ultralytics.yolo.utils.torch_utils import de_parallel | ||||
| 
 | ||||
| from ..detect.train import Loss | ||||
| from ultralytics.yolo.v8.detect.train import Loss | ||||
| 
 | ||||
| 
 | ||||
| # BaseTrainer python usage | ||||
| class SegmentationTrainer(v8.detect.DetectionTrainer): | ||||
| 
 | ||||
|     def __init__(self, config=DEFAULT_CONFIG, overrides={}): | ||||
|     def __init__(self, config=DEFAULT_CONFIG, overrides=None): | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         overrides["task"] = "segment" | ||||
|         super().__init__(config, overrides) | ||||
| 
 | ||||
|  | ||||
| @ -13,8 +13,7 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops | ||||
| from ultralytics.yolo.utils.checks import check_requirements | ||||
| from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou | ||||
| from ultralytics.yolo.utils.plotting import output_to_target, plot_images | ||||
| 
 | ||||
| from ..detect import DetectionValidator | ||||
| from ultralytics.yolo.v8.detect import DetectionValidator | ||||
| 
 | ||||
| 
 | ||||
| class SegmentationValidator(DetectionValidator): | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher