diff --git a/docs/index.md b/docs/index.md
index 55d7c5e9..00951dd2 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -6,7 +6,7 @@ keywords: Ultralytics, YOLOv8, object detection, image segmentation, machine lea
-
+

diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb
index 12d278ce..19eefd3e 100644
--- a/examples/tutorial.ipynb
+++ b/examples/tutorial.ipynb
@@ -563,6 +563,18 @@
"Additional content below."
]
},
+ {
+ "cell_type": "code",
+ "source": [
+ "# Pip install from source\n",
+ "!pip install git+https://github.com/ultralytics/ultralytics@main"
+ ],
+ "metadata": {
+ "id": "pIdE6i8C3LYp"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
{
"cell_type": "code",
"source": [
diff --git a/tests/test_cuda.py b/tests/test_cuda.py
index f11f8fa0..52972a09 100644
--- a/tests/test_cuda.py
+++ b/tests/test_cuda.py
@@ -16,6 +16,7 @@ DATASETS_DIR = Path(SETTINGS['datasets_dir'])
WEIGHTS_DIR = Path(SETTINGS['weights_dir'])
MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt' # test spaces in path
DATA = 'coco8.yaml'
+BUS = ASSETS / 'bus.jpg'
def test_checks():
@@ -29,6 +30,30 @@ def test_train():
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device) # requires imgsz>=64
+@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
+def test_predict_multiple_devices():
+ model = YOLO('yolov8n.pt')
+ model = model.cpu()
+ assert str(model.device) == 'cpu'
+ _ = model(BUS) # CPU inference
+ assert str(model.device) == 'cpu'
+
+ model = model.to('cuda:0')
+ assert str(model.device) == 'cuda:0'
+ _ = model(BUS) # CUDA inference
+ assert str(model.device) == 'cuda:0'
+
+ model = model.cpu()
+ assert str(model.device) == 'cpu'
+ _ = model(BUS) # CPU inference
+ assert str(model.device) == 'cpu'
+
+ model = model.cuda()
+ assert str(model.device) == 'cuda:0'
+ _ = model(BUS) # CUDA inference
+ assert str(model.device) == 'cuda:0'
+
+
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_autobatch():
from ultralytics.utils.autobatch import check_train_batch_size
@@ -57,10 +82,10 @@ def test_predict_sam():
model.info()
# Run inference
- model(ASSETS / 'bus.jpg', device=0)
+ model(BUS, device=0)
# Run inference with bboxes prompt
- model(ASSETS / 'zidane.jpg', bboxes=[439, 437, 524, 709], device=0)
+ model(BUS, bboxes=[439, 437, 524, 709], device=0)
# Run inference with points prompt
model(ASSETS / 'zidane.jpg', points=[900, 370], labels=[1], device=0)
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 1312c2d7..76196e34 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.187'
+__version__ = '8.0.188'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index 5973e73f..d8c881a9 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -8,8 +8,7 @@ from typing import Union
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
-from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, emojis, yaml_load
-from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
+from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
@@ -139,7 +138,7 @@ class Model(nn.Module):
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path
else:
- weights = check_file(weights)
+ weights = checks.check_file(weights)
self.model, self.ckpt = weights, None
self.task = task or guess_model_task(weights)
self.ckpt_path = weights
@@ -204,11 +203,11 @@ class Model(nn.Module):
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
- Accepts all source types accepted by the YOLO model.
+ Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
predictor (BasePredictor): Customized predictor.
**kwargs : Additional keyword arguments passed to the predictor.
- Check the 'configuration' section in the documentation for all available options.
+ Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
@@ -251,8 +250,7 @@ class Model(nn.Module):
if not hasattr(self.predictor, 'trackers'):
from ultralytics.trackers import register_tracker
register_tracker(self, persist)
- # ByteTrack-based method needs low confidence predictions as input
- kwargs['conf'] = kwargs.get('conf') or 0.1
+ kwargs['conf'] = kwargs.get('conf') or 0.1 # ByteTrack-based method needs low confidence predictions as input
kwargs['mode'] = 'track'
return self.predict(source=source, stream=stream, **kwargs)
@@ -266,7 +264,6 @@ class Model(nn.Module):
"""
custom = {'rect': True} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right
- args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1)
validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks)
validator(model=self.model)
@@ -321,9 +318,9 @@ class Model(nn.Module):
if any(kwargs):
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
kwargs = self.session.train_args
- check_pip_update_available()
+ checks.check_pip_update_available()
- overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
+ overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
custom = {'data': TASK2DATA[self.task]} # method defaults
args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
if args.get('resume'):
@@ -366,7 +363,7 @@ class Model(nn.Module):
self._check_is_pytorch_model()
self = super()._apply(fn) # noqa
self.predictor = None # reset predictor as device may have changed
- self.overrides['device'] = str(self.device) # i.e. device(type='cuda', index=0) -> 'cuda:0'
+ self.overrides['device'] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
return self
@property
diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py
index 9730c9be..8d8349bd 100644
--- a/ultralytics/engine/validator.py
+++ b/ultralytics/engine/validator.py
@@ -95,6 +95,7 @@ class BaseValidator:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001
+ self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
self.plots = {}
self.callbacks = _callbacks or callbacks.get_default_callbacks()
diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py
index 9d5ae253..97ab46c3 100644
--- a/ultralytics/models/fastsam/prompt.py
+++ b/ultralytics/models/fastsam/prompt.py
@@ -87,7 +87,7 @@ class FastSAMPrompt:
pbar = TQDM(annotations, total=len(annotations))
for ann in pbar:
result_name = os.path.basename(ann.path)
- image = ann.orig_img
+ image = ann.orig_img[..., ::-1] # BGR to RGB
original_h, original_w = ann.orig_shape
# for macOS only
# plt.switch_backend('TkAgg')
@@ -108,17 +108,15 @@ class FastSAMPrompt:
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
- self.fast_show_mask(
- masks,
- plt.gca(),
- random_color=mask_random_color,
- bbox=bbox,
- points=points,
- pointlabel=point_label,
- retinamask=retina,
- target_height=original_h,
- target_width=original_w,
- )
+ self.fast_show_mask(masks,
+ plt.gca(),
+ random_color=mask_random_color,
+ bbox=bbox,
+ points=points,
+ pointlabel=point_label,
+ retinamask=retina,
+ target_height=original_h,
+ target_width=original_w)
if with_contours:
contour_all = []
@@ -134,17 +132,11 @@ class FastSAMPrompt:
contour_mask = temp / 255 * color.reshape(1, 1, -1)
plt.imshow(contour_mask)
- plt.axis('off')
- fig = plt.gcf()
-
- # Check if the canvas has been drawn
- if fig.canvas.get_renderer() is None: # macOS requires this or tests fail
- fig.canvas.draw()
-
+ # Save the figure
save_path = Path(output) / result_name
save_path.parent.mkdir(exist_ok=True, parents=True)
- image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
- image.save(save_path)
+ plt.axis('off')
+ plt.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True)
plt.close()
pbar.set_description(f'Saving {result_name} to {save_path}')
@@ -263,8 +255,8 @@ class FastSAMPrompt:
orig_masks_area = torch.sum(masks, dim=(1, 2))
union = bbox_area + orig_masks_area - masks_area
- IoUs = masks_area / union
- max_iou_index = torch.argmax(IoUs)
+ iou = masks_area / union
+ max_iou_index = torch.argmax(iou)
self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
return self.results
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 9010815b..633aa9db 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -39,6 +39,7 @@ def check_class_names(names):
class AutoBackend(nn.Module):
+ @torch.no_grad()
def __init__(self,
weights='yolov8n.pt',
device=torch.device('cpu'),
@@ -309,6 +310,11 @@ class AutoBackend(nn.Module):
names = self._apply_default_class_names(data)
names = check_class_names(names)
+ # Disable gradients
+ if pt:
+ for p in model.parameters():
+ p.requires_grad = False
+
self.__dict__.update(locals()) # assign all variables to self
def forward(self, im, augment=False, visualize=False):
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index bcf4c805..4313c263 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -327,8 +327,9 @@ def yaml_save(file='data.yaml', data=None, header=''):
file.parent.mkdir(parents=True, exist_ok=True)
# Convert Path objects to strings
+ valid_types = int, float, str, bool, list, tuple, dict, type(None)
for k, v in data.items():
- if isinstance(v, Path):
+ if not isinstance(v, valid_types):
data[k] = str(v)
# Dump data to file in YAML format
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index cbb90365..5df86754 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -55,7 +55,7 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
line = line.strip()
if line and not line.startswith('#'):
line = line.split('#')[0].strip() # ignore inline comments
- match = re.match(r'([a-zA-Z0-9-_]+)([<>!=~]+.*)?', line)
+ match = re.match(r'([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?', line)
if match:
requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ''))
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 9a35946a..b9774576 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -44,7 +44,10 @@ def smart_inference_mode():
def decorate(fn):
"""Applies appropriate torch decorator for inference mode based on torch version."""
- return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
+ if TORCH_1_9 and torch.is_inference_mode_enabled():
+ return fn # already in inference_mode, act as a pass-through
+ else:
+ return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
return decorate