mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
test updates, revert Results to CPU
This commit is contained in:
parent
3ea659411b
commit
bfc078b32f
@ -48,7 +48,7 @@ def test_val_classify():
|
||||
|
||||
# Predict checks -------------------------------------------------------------------------------------------------------
|
||||
def test_predict_detect():
|
||||
run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
|
||||
run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32 save")
|
||||
if checks.check_online():
|
||||
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
|
||||
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
|
||||
@ -56,11 +56,11 @@ def test_predict_detect():
|
||||
|
||||
|
||||
def test_predict_segment():
|
||||
run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")
|
||||
run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32 save")
|
||||
|
||||
|
||||
def test_predict_classify():
|
||||
run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")
|
||||
run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32 save")
|
||||
|
||||
|
||||
# Export checks --------------------------------------------------------------------------------------------------------
|
||||
|
@ -42,9 +42,9 @@ class Results:
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None:
|
||||
self.orig_img = orig_img
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes.cpu(), self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks.cpu(), self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = probs.cpu() if probs is not None else None
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = probs if probs is not None else None
|
||||
self.names = names
|
||||
self.path = path
|
||||
self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
|
||||
|
@ -114,7 +114,9 @@ class Annotator:
|
||||
self.im = np.asarray(self.im).copy()
|
||||
if len(masks) == 0:
|
||||
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
||||
colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
|
||||
if im_gpu.device != masks.device:
|
||||
im_gpu = im_gpu.to(masks.device)
|
||||
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0
|
||||
colors = colors[:, None, None] # shape(n,1,1,3)
|
||||
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
||||
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user