From 19c3314e68b47c8c5cb799f709868b083921067a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 26 Sep 2023 20:28:45 +0200 Subject: [PATCH] `ultralytics 8.0.188` fix .grad attribute leaf Tensor Warning (#5094) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/index.md | 2 +- examples/tutorial.ipynb | 12 +++++++++ tests/test_cuda.py | 29 +++++++++++++++++++-- ultralytics/__init__.py | 2 +- ultralytics/engine/model.py | 19 ++++++-------- ultralytics/engine/validator.py | 1 + ultralytics/models/fastsam/prompt.py | 38 +++++++++++----------------- ultralytics/nn/autobackend.py | 6 +++++ ultralytics/utils/__init__.py | 3 ++- ultralytics/utils/checks.py | 2 +- ultralytics/utils/torch_utils.py | 5 +++- 11 files changed, 78 insertions(+), 41 deletions(-) diff --git a/docs/index.md b/docs/index.md index 55d7c5e9..00951dd2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,7 +6,7 @@ keywords: Ultralytics, YOLOv8, object detection, image segmentation, machine lea

- +

Ultralytics CI diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb index 12d278ce..19eefd3e 100644 --- a/examples/tutorial.ipynb +++ b/examples/tutorial.ipynb @@ -563,6 +563,18 @@ "Additional content below." ] }, + { + "cell_type": "code", + "source": [ + "# Pip install from source\n", + "!pip install git+https://github.com/ultralytics/ultralytics@main" + ], + "metadata": { + "id": "pIdE6i8C3LYp" + }, + "execution_count": null, + "outputs": [] + }, { "cell_type": "code", "source": [ diff --git a/tests/test_cuda.py b/tests/test_cuda.py index f11f8fa0..52972a09 100644 --- a/tests/test_cuda.py +++ b/tests/test_cuda.py @@ -16,6 +16,7 @@ DATASETS_DIR = Path(SETTINGS['datasets_dir']) WEIGHTS_DIR = Path(SETTINGS['weights_dir']) MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt' # test spaces in path DATA = 'coco8.yaml' +BUS = ASSETS / 'bus.jpg' def test_checks(): @@ -29,6 +30,30 @@ def test_train(): YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device) # requires imgsz>=64 +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') +def test_predict_multiple_devices(): + model = YOLO('yolov8n.pt') + model = model.cpu() + assert str(model.device) == 'cpu' + _ = model(BUS) # CPU inference + assert str(model.device) == 'cpu' + + model = model.to('cuda:0') + assert str(model.device) == 'cuda:0' + _ = model(BUS) # CUDA inference + assert str(model.device) == 'cuda:0' + + model = model.cpu() + assert str(model.device) == 'cpu' + _ = model(BUS) # CPU inference + assert str(model.device) == 'cpu' + + model = model.cuda() + assert str(model.device) == 'cuda:0' + _ = model(BUS) # CUDA inference + assert str(model.device) == 'cuda:0' + + @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available') def test_autobatch(): from ultralytics.utils.autobatch import check_train_batch_size @@ -57,10 +82,10 @@ def test_predict_sam(): model.info() # Run inference - model(ASSETS / 'bus.jpg', device=0) + model(BUS, device=0) # Run inference with bboxes prompt - model(ASSETS / 'zidane.jpg', bboxes=[439, 437, 524, 709], device=0) + model(BUS, bboxes=[439, 437, 524, 709], device=0) # Run inference with points prompt model(ASSETS / 'zidane.jpg', points=[900, 370], labels=[1], device=0) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 1312c2d7..76196e34 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.187' +__version__ = '8.0.188' from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models.fastsam import FastSAM diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 5973e73f..d8c881a9 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -8,8 +8,7 @@ from typing import Union from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir from ultralytics.hub.utils import HUB_WEB_ROOT from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load -from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, emojis, yaml_load -from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml +from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS @@ -139,7 +138,7 @@ class Model(nn.Module): self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) self.ckpt_path = self.model.pt_path else: - weights = check_file(weights) + weights = checks.check_file(weights) self.model, self.ckpt = weights, None self.task = task or guess_model_task(weights) self.ckpt_path = weights @@ -204,11 +203,11 @@ class Model(nn.Module): Args: source (str | int | PIL | np.ndarray): The source of the image to make predictions on. - Accepts all source types accepted by the YOLO model. + Accepts all source types accepted by the YOLO model. stream (bool): Whether to stream the predictions or not. Defaults to False. predictor (BasePredictor): Customized predictor. **kwargs : Additional keyword arguments passed to the predictor. - Check the 'configuration' section in the documentation for all available options. + Check the 'configuration' section in the documentation for all available options. Returns: (List[ultralytics.engine.results.Results]): The prediction results. @@ -251,8 +250,7 @@ class Model(nn.Module): if not hasattr(self.predictor, 'trackers'): from ultralytics.trackers import register_tracker register_tracker(self, persist) - # ByteTrack-based method needs low confidence predictions as input - kwargs['conf'] = kwargs.get('conf') or 0.1 + kwargs['conf'] = kwargs.get('conf') or 0.1 # ByteTrack-based method needs low confidence predictions as input kwargs['mode'] = 'track' return self.predict(source=source, stream=stream, **kwargs) @@ -266,7 +264,6 @@ class Model(nn.Module): """ custom = {'rect': True} # method defaults args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right - args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1) validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks) validator(model=self.model) @@ -321,9 +318,9 @@ class Model(nn.Module): if any(kwargs): LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.') kwargs = self.session.train_args - check_pip_update_available() + checks.check_pip_update_available() - overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides + overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides custom = {'data': TASK2DATA[self.task]} # method defaults args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right if args.get('resume'): @@ -366,7 +363,7 @@ class Model(nn.Module): self._check_is_pytorch_model() self = super()._apply(fn) # noqa self.predictor = None # reset predictor as device may have changed - self.overrides['device'] = str(self.device) # i.e. device(type='cuda', index=0) -> 'cuda:0' + self.overrides['device'] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' return self @property diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index 9730c9be..8d8349bd 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -95,6 +95,7 @@ class BaseValidator: (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) if self.args.conf is None: self.args.conf = 0.001 # default conf=0.001 + self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1) self.plots = {} self.callbacks = _callbacks or callbacks.get_default_callbacks() diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py index 9d5ae253..97ab46c3 100644 --- a/ultralytics/models/fastsam/prompt.py +++ b/ultralytics/models/fastsam/prompt.py @@ -87,7 +87,7 @@ class FastSAMPrompt: pbar = TQDM(annotations, total=len(annotations)) for ann in pbar: result_name = os.path.basename(ann.path) - image = ann.orig_img + image = ann.orig_img[..., ::-1] # BGR to RGB original_h, original_w = ann.orig_shape # for macOS only # plt.switch_backend('TkAgg') @@ -108,17 +108,15 @@ class FastSAMPrompt: mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) - self.fast_show_mask( - masks, - plt.gca(), - random_color=mask_random_color, - bbox=bbox, - points=points, - pointlabel=point_label, - retinamask=retina, - target_height=original_h, - target_width=original_w, - ) + self.fast_show_mask(masks, + plt.gca(), + random_color=mask_random_color, + bbox=bbox, + points=points, + pointlabel=point_label, + retinamask=retina, + target_height=original_h, + target_width=original_w) if with_contours: contour_all = [] @@ -134,17 +132,11 @@ class FastSAMPrompt: contour_mask = temp / 255 * color.reshape(1, 1, -1) plt.imshow(contour_mask) - plt.axis('off') - fig = plt.gcf() - - # Check if the canvas has been drawn - if fig.canvas.get_renderer() is None: # macOS requires this or tests fail - fig.canvas.draw() - + # Save the figure save_path = Path(output) / result_name save_path.parent.mkdir(exist_ok=True, parents=True) - image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) - image.save(save_path) + plt.axis('off') + plt.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True) plt.close() pbar.set_description(f'Saving {result_name} to {save_path}') @@ -263,8 +255,8 @@ class FastSAMPrompt: orig_masks_area = torch.sum(masks, dim=(1, 2)) union = bbox_area + orig_masks_area - masks_area - IoUs = masks_area / union - max_iou_index = torch.argmax(IoUs) + iou = masks_area / union + max_iou_index = torch.argmax(iou) self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()])) return self.results diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 9010815b..633aa9db 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -39,6 +39,7 @@ def check_class_names(names): class AutoBackend(nn.Module): + @torch.no_grad() def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), @@ -309,6 +310,11 @@ class AutoBackend(nn.Module): names = self._apply_default_class_names(data) names = check_class_names(names) + # Disable gradients + if pt: + for p in model.parameters(): + p.requires_grad = False + self.__dict__.update(locals()) # assign all variables to self def forward(self, im, augment=False, visualize=False): diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index bcf4c805..4313c263 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -327,8 +327,9 @@ def yaml_save(file='data.yaml', data=None, header=''): file.parent.mkdir(parents=True, exist_ok=True) # Convert Path objects to strings + valid_types = int, float, str, bool, list, tuple, dict, type(None) for k, v in data.items(): - if isinstance(v, Path): + if not isinstance(v, valid_types): data[k] = str(v) # Dump data to file in YAML format diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index cbb90365..5df86754 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -55,7 +55,7 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''): line = line.strip() if line and not line.startswith('#'): line = line.split('#')[0].strip() # ignore inline comments - match = re.match(r'([a-zA-Z0-9-_]+)([<>!=~]+.*)?', line) + match = re.match(r'([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?', line) if match: requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else '')) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 9a35946a..b9774576 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -44,7 +44,10 @@ def smart_inference_mode(): def decorate(fn): """Applies appropriate torch decorator for inference mode based on torch version.""" - return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn) + if TORCH_1_9 and torch.is_inference_mode_enabled(): + return fn # already in inference_mode, act as a pass-through + else: + return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn) return decorate