Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 21:44:22 +08:00

ultralytics 8.0.170 apply is_list fixes for torch.Tensor inputs (#4713)

Co-authored-by: Gezhi Zhang <765724965@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Commit aa9133bb88, parent a1c1d6b483
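The commit applies one pattern across every predictor's `postprocess()`: instead of carrying an `is_list` flag and branching on it wherever boxes are scaled back to the original image, each predictor now normalizes `orig_imgs` once up front, converting a `torch.Tensor` batch to NumPy with the new `ops.convert_torch2numpy_batch` helper so the per-image loop can always index `orig_imgs[i]`. A minimal sketch of the shared pattern (the helper name `normalize_orig_imgs` is hypothetical, not part of the commit):

```python
import torch
from ultralytics.utils import ops  # ships convert_torch2numpy_batch as of 8.0.170

def normalize_orig_imgs(orig_imgs):
    """Hypothetical helper mirroring the pattern this commit repeats in each predictor."""
    if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
        orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)  # BCHW float32 -> BHWC uint8 numpy
    return orig_imgs  # now always indexable as orig_imgs[i] -> one HWC image
```

This removes three kinds of branches per predictor: the `is_list` assignment, the `orig_imgs[i] if is_list else orig_imgs` index, and the `if is_list:` guards around box scaling.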
docs/reference/data/utils.md
@@ -45,6 +45,10 @@ keywords: Ultralytics, data utils, YOLO, img2label_paths, exif_size, polygon2mas
 ## ::: ultralytics.data.utils.polygons2masks_overlap
 <br><br>
 
+---
+## ::: ultralytics.data.utils.find_dataset_yaml
+<br><br>
+
 ---
 ## ::: ultralytics.data.utils.check_det_dataset
 <br><br>
docs/reference/utils/__init__.md
@@ -9,6 +9,10 @@ keywords: Ultralytics, Utils, utilitarian functions, colorstr, yaml_save, set_lo
 Full source code for this file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/__init__.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/__init__.py). Help us fix any issues you see by submitting a [Pull Request](https://docs.ultralytics.com/help/contributing/) 🛠️. Thank you 🙏!
 
+---
+## ::: ultralytics.utils.TQDM
+<br><br>
+
 ---
 ## ::: ultralytics.utils.SimpleClass
 <br><br>
docs/reference/utils/ops.md
@@ -117,6 +117,10 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli
 ## ::: ultralytics.utils.ops.masks2segments
 <br><br>
 
+---
+## ::: ultralytics.utils.ops.convert_torch2numpy_batch
+<br><br>
+
 ---
 ## ::: ultralytics.utils.ops.clean_str
 <br><br>
ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = '8.0.169'
+__version__ = '8.0.170'
 
 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM
ultralytics/engine/results.py
@@ -205,7 +205,7 @@ class Results(SimpleClass):
             ```
         """
         if img is None and isinstance(self.orig_img, torch.Tensor):
-            img = (self.orig_img[0].detach().permute(1, 2, 0).cpu().contiguous() * 255).to(torch.uint8).numpy()
+            img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()
 
         # Deprecation warn TODO: remove in 8.2
         if 'show_conf' in kwargs:
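The `Results.plot()` tweak only reorders the tensor-to-image chain: `.cpu()` now runs after the uint8 cast instead of before `.contiguous()`. A plausible motivation (my inference, not stated in the commit) is that the device-to-host copy then moves 1 byte per element rather than 4. A quick equivalence sketch:

```python
import torch

t = torch.rand(3, 640, 640, device='cuda' if torch.cuda.is_available() else 'cpu')
old = (t.permute(1, 2, 0).cpu().contiguous() * 255).to(torch.uint8).numpy()  # copy float32 to host, cast there
new = (t.permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()  # cast on device, copy uint8
assert (old == new).all()  # same pixels either way
```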
ultralytics/models/fastsam/predict.py
@@ -30,21 +30,22 @@ class FastSAMPredictor(DetectionPredictor):
         full_box[0][4] = p[0][critical_iou_index][:, 4]
         full_box[0][6:] = p[0][critical_iou_index][:, 6:]
         p[0][critical_iou_index] = full_box
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
         results = []
-        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         for i, pred in enumerate(p):
-            orig_img = orig_imgs[i] if is_list else orig_imgs
+            orig_img = orig_imgs[i]
             img_path = self.batch[0][i]
             if not len(pred):  # save empty boxes
                 masks = None
             elif self.args.retina_masks:
-                if is_list:
-                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                 masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
             else:
                 masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
-                if is_list:
-                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
         return results
ultralytics/models/nas/predict.py
@@ -23,12 +23,13 @@ class NASPredictor(BasePredictor):
                                         max_det=self.args.max_det,
                                         classes=self.args.classes)
 
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
         results = []
-        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i] if is_list else orig_imgs
-            if is_list:
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            orig_img = orig_imgs[i]
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
             img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
ultralytics/models/rtdetr/predict.py
@@ -27,8 +27,11 @@ class RTDETRPredictor(BasePredictor):
         """Postprocess predictions and returns a list of Results objects."""
         nd = preds[0].shape[-1]
         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
         results = []
-        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, bbox in enumerate(bboxes):  # (300, 4)
             bbox = ops.xywh2xyxy(bbox)
             score, cls = scores[i].max(-1, keepdim=True)  # (300, 1)
@@ -36,11 +39,10 @@ class RTDETRPredictor(BasePredictor):
             if self.args.classes is not None:
                 idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
             pred = torch.cat([bbox, score, cls], dim=-1)[idx]  # filter
-            orig_img = orig_imgs[i] if is_list else orig_imgs
+            orig_img = orig_imgs[i]
             oh, ow = orig_img.shape[:2]
-            if is_list:
-                pred[..., [0, 2]] *= ow
-                pred[..., [1, 3]] *= oh
+            pred[..., [0, 2]] *= ow
+            pred[..., [1, 3]] *= oh
             img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
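RT-DETR emits center-format boxes normalized to [0, 1], which is why `postprocess()` converts them with `ops.xywh2xyxy` and then multiplies by the original width and height; with `orig_imgs` now always a NumPy batch, that scaling no longer needs an `is_list` guard. For reference, the corner conversion on its own:

```python
import torch
from ultralytics.utils import ops

box = torch.tensor([[320.0, 240.0, 100.0, 50.0]])  # cx, cy, w, h
print(ops.xywh2xyxy(box))  # tensor([[270., 215., 370., 265.]]), i.e. x1, y1, x2, y2
```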
ultralytics/models/sam/predict.py
@@ -312,10 +312,13 @@ class Predictor(BasePredictor):
         pred_masks, pred_scores = preds[:2]
         pred_bboxes = preds[2] if self.segment_all else None
         names = dict(enumerate(str(i) for i in range(len(pred_masks))))
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
         results = []
-        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, masks in enumerate([pred_masks]):
-            orig_img = orig_imgs[i] if is_list else orig_imgs
+            orig_img = orig_imgs[i]
             if pred_bboxes is not None:
                 pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
             cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
ultralytics/models/yolo/classify/predict.py
@@ -4,7 +4,7 @@ import torch
 
 from ultralytics.engine.predictor import BasePredictor
 from ultralytics.engine.results import Results
-from ultralytics.utils import DEFAULT_CFG
+from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class ClassificationPredictor(BasePredictor):
@@ -38,10 +38,12 @@ class ClassificationPredictor(BasePredictor):
 
     def postprocess(self, preds, img, orig_imgs):
         """Post-processes predictions to return Results objects."""
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
         results = []
-        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i] if is_list else orig_imgs
+            orig_img = orig_imgs[i]
             img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
         return results
ultralytics/models/yolo/detect/predict.py
@@ -29,12 +29,13 @@ class DetectionPredictor(BasePredictor):
                                         max_det=self.args.max_det,
                                         classes=self.args.classes)
 
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
         results = []
-        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i] if is_list else orig_imgs
-            if is_list:
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            orig_img = orig_imgs[i]
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
             img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
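With the branching gone, a preprocessed tensor batch and a list of frames flow through the same loop. A usage sketch (assumes a `yolov8n.pt` checkpoint is available; tensor sources must already be BCHW float32 in 0.0-1.0):

```python
import torch
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
batch = torch.rand(2, 3, 640, 640)  # BCHW float32 in [0, 1]
results = model(batch)  # orig_imgs arrives as a Tensor and is converted once by the predictor
print(len(results))  # 2 Results objects, one per image
```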
ultralytics/models/yolo/pose/predict.py
@@ -37,10 +37,12 @@ class PosePredictor(DetectionPredictor):
                                         classes=self.args.classes,
                                         nc=len(self.model.names))
 
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
         results = []
-        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i] if is_list else orig_imgs
+            orig_img = orig_imgs[i]
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
             pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
             pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
ultralytics/models/yolo/segment/predict.py
@@ -32,21 +32,22 @@ class SegmentationPredictor(DetectionPredictor):
                                     max_det=self.args.max_det,
                                     nc=len(self.model.names),
                                     classes=self.args.classes)
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
         results = []
-        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         for i, pred in enumerate(p):
-            orig_img = orig_imgs[i] if is_list else orig_imgs
+            orig_img = orig_imgs[i]
             img_path = self.batch[0][i]
             if not len(pred):  # save empty boxes
                 masks = None
             elif self.args.retina_masks:
-                if is_list:
-                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                 masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
             else:
                 masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
-                if is_list:
-                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
         return results
ultralytics/utils/__init__.py
@@ -112,8 +112,8 @@ class TQDM(tqdm_original):
     Custom Ultralytics tqdm class with different default arguments.
 
     Args:
-        (*args): Positional arguments passed to original tqdm.
-        (**kwargs): Keyword arguments, with custom defaults applied.
+        *args (list): Positional arguments passed to original tqdm.
+        **kwargs (dict): Keyword arguments, with custom defaults applied.
     """
 
     def __init__(self, *args, **kwargs):
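The docstring edit just brings the `*args`/`**kwargs` entries into the Google-style `name (type): description` form used elsewhere in the codebase. Behavior is unchanged; a quick usage sketch:

```python
from ultralytics.utils import TQDM

for _ in TQDM(range(3), desc='processing'):  # Ultralytics defaults (e.g. custom bar_format) apply unless overridden
    pass
```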
ultralytics/utils/ops.py
@@ -771,6 +771,19 @@ def masks2segments(masks, strategy='largest'):
     return segments
 
 
+def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
+    """
+    Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.
+
+    Args:
+        batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.
+
+    Returns:
+        (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
+    """
+    return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
+
+
 def clean_str(s):
     """
     Cleans a string by replacing special characters with underscore _
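A quick shape-and-dtype check of the new helper (random values, just to show the BCHW float to BHWC uint8 round trip):

```python
import numpy as np
import torch
from ultralytics.utils import ops

batch = torch.rand(2, 3, 640, 480)           # BCHW float32 in [0, 1]
imgs = ops.convert_torch2numpy_batch(batch)  # BHWC uint8 in [0, 255]
assert imgs.shape == (2, 640, 480, 3) and imgs.dtype == np.uint8
```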