mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Fix model re-fuse() in inference loops (#466)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
cc3c774bde
commit
a86218b767
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
__version__ = "8.0.8"
|
||||
__version__ = "8.0.9"
|
||||
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import ops
|
||||
|
@ -63,7 +63,8 @@ class BaseModel(nn.Module):
|
||||
|
||||
def _profile_one_layer(self, m, x, dt):
|
||||
"""
|
||||
Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to the provided list.
|
||||
Profile the computation time and FLOPs of a single layer of the model on a given input.
|
||||
Appends the results to the provided list.
|
||||
|
||||
Args:
|
||||
m (nn.Module): The layer to be profiled.
|
||||
@ -74,10 +75,10 @@ class BaseModel(nn.Module):
|
||||
None
|
||||
"""
|
||||
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
||||
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
||||
o = thop.profile(m, inputs=(x.clone() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
||||
t = time_sync()
|
||||
for _ in range(10):
|
||||
m(x.copy() if c else x)
|
||||
m(x.clone() if c else x)
|
||||
dt.append((time_sync() - t) * 100)
|
||||
if m == self.model[0]:
|
||||
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
||||
@ -87,20 +88,36 @@ class BaseModel(nn.Module):
|
||||
|
||||
def fuse(self):
|
||||
"""
|
||||
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the computation efficiency.
|
||||
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
|
||||
computation efficiency.
|
||||
|
||||
Returns:
|
||||
(nn.Module): The fused model is returned.
|
||||
"""
|
||||
LOGGER.info('Fusing layers... ')
|
||||
if not self.is_fused():
|
||||
LOGGER.info('Fusing... ')
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||
delattr(m, 'bn') # remove batchnorm
|
||||
m.forward = m.forward_fuse # update forward
|
||||
self.info()
|
||||
|
||||
return self
|
||||
|
||||
def is_fused(self, thresh=10):
|
||||
"""
|
||||
Check if the model has less than a certain threshold of BatchNorm layers.
|
||||
|
||||
Args:
|
||||
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
|
||||
|
||||
Returns:
|
||||
bool: True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
||||
"""
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||
|
||||
def info(self, verbose=False, imgsz=640):
|
||||
"""
|
||||
Prints model information
|
||||
|
@ -1,7 +1,9 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics import __version__, yolo
|
||||
@ -17,7 +19,7 @@ CLI_HELP_MSG = \
|
||||
|
||||
pip install ultralytics
|
||||
|
||||
2. Train, Val, Predict and Export using 'yolo' commands of the form:
|
||||
2. Train, Val, Predict and Export using 'yolo' commands:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
@ -97,9 +99,14 @@ def entrypoint():
|
||||
It uses the package's default config and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed config
|
||||
"""
|
||||
if len(sys.argv) == 1: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
parser = argparse.ArgumentParser(description='YOLO parser')
|
||||
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
|
||||
args = parser.parse_args().args
|
||||
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
|
||||
|
||||
tasks = 'detect', 'segment', 'classify'
|
||||
modes = 'train', 'val', 'predict', 'export'
|
||||
|
@ -8,7 +8,7 @@ mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
|
||||
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
|
||||
data: null # i.e. coco128.yaml. Path to data file
|
||||
epochs: 100 # number of epochs to train for
|
||||
patience: 50 # TODO: epochs to wait for no observable improvement for early stopping of training
|
||||
patience: 50 # epochs to wait for no observable improvement for early stopping of training
|
||||
batch: 16 # number of images per batch
|
||||
imgsz: 640 # size of input images
|
||||
save: True # save checkpoints
|
||||
|
@ -28,10 +28,9 @@ names:
|
||||
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
||||
download: |
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
from utils.general import download, Path
|
||||
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
def argoverse2yolo(set):
|
||||
labels = {}
|
||||
|
@ -32,8 +32,8 @@ names:
|
||||
|
||||
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
||||
download: |
|
||||
from utils.general import download, Path
|
||||
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
# Download
|
||||
dir = Path(yaml['path']) # dataset root dir
|
||||
|
@ -386,7 +386,12 @@ names:
|
||||
download: |
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.general import Path, check_requirements, download, np, xyxy2xywhn
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywhn
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
check_requirements(('pycocotools>=2.0',))
|
||||
from pycocotools.coco import COCO
|
||||
|
@ -21,9 +21,14 @@ names:
|
||||
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
||||
download: |
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
from utils.general import np, pd, Path, download, xyxy2xywh
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywh
|
||||
|
||||
# Download
|
||||
dir = Path(yaml['path']) # dataset root dir
|
||||
|
@ -48,8 +48,8 @@ download: |
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from tqdm import tqdm
|
||||
from utils.general import download, Path
|
||||
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
def convert_label(path, lb_path, year, image_id):
|
||||
def convert_box(size, box):
|
||||
|
@ -29,7 +29,10 @@ names:
|
||||
|
||||
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
||||
download: |
|
||||
from utils.general import download, os, Path
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
|
||||
def visdrone2yolo(dir):
|
||||
from PIL import Image
|
||||
|
@ -99,7 +99,9 @@ names:
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: |
|
||||
from utils.general import download, Path
|
||||
from ultralytics.yoloutils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
# Download labels
|
||||
segments = True # segment or box labels
|
||||
dir = Path(yaml['path']) # dataset root dir
|
||||
|
@ -87,8 +87,8 @@ download: |
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.dataloaders import autosplit
|
||||
from utils.general import download, xyxy2xywhn
|
||||
from ultralytics.yolo.data.dataloaders.v5loader import autosplit
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywhn
|
||||
|
||||
|
||||
def convert_labels(fname=Path('xView/xView_train.geojson')):
|
||||
|
@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||
|
||||
# Map head to model, trainer, validator, and predictor classes
|
||||
@ -43,6 +43,7 @@ class YOLO:
|
||||
self.TrainerClass = None # trainer class
|
||||
self.ValidatorClass = None # validator class
|
||||
self.PredictorClass = None # predictor class
|
||||
self.predictor = None # reuse predictor
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
self.task = None # task type
|
||||
@ -131,11 +132,12 @@ class YOLO:
|
||||
overrides.update(kwargs)
|
||||
overrides["mode"] = "predict"
|
||||
overrides["save"] = kwargs.get("save", False) # not save files by default
|
||||
predictor = self.PredictorClass(overrides=overrides)
|
||||
|
||||
predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
|
||||
predictor.setup(model=self.model, source=source)
|
||||
return predictor(stream=stream, verbose=verbose)
|
||||
if not self.predictor:
|
||||
self.predictor = self.PredictorClass(overrides=overrides)
|
||||
self.predictor.setup_model(model=self.model)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_config(self.predictor.args, overrides)
|
||||
return self.predictor(source=source, stream=stream, verbose=verbose)
|
||||
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
@ -170,6 +172,7 @@ class YOLO:
|
||||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
||||
args.task = self.task
|
||||
|
||||
print(args)
|
||||
exporter = Exporter(overrides=args)
|
||||
exporter(model=self.model)
|
||||
|
||||
@ -224,10 +227,14 @@ class YOLO:
|
||||
def _reset_ckpt_args(args):
|
||||
args.pop("project", None)
|
||||
args.pop("name", None)
|
||||
args.pop("exist_ok", None)
|
||||
args.pop("resume", None)
|
||||
args.pop("batch", None)
|
||||
args.pop("epochs", None)
|
||||
args.pop("cache", None)
|
||||
args.pop("save_json", None)
|
||||
args.pop("half", None)
|
||||
args.pop("v5loader", None)
|
||||
|
||||
# set device to '' to prevent from auto DDP usage
|
||||
args["device"] = ''
|
||||
|
@ -76,15 +76,15 @@ class BasePredictor:
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f"{self.args.mode}"
|
||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||
if self.args.save:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.25 # default conf=0.25
|
||||
self.done_setup = False
|
||||
self.done_warmup = False
|
||||
|
||||
# Usable if setup is done
|
||||
self.model = None
|
||||
self.data = self.args.data # data_dict
|
||||
self.bs = None
|
||||
self.imgsz = None
|
||||
self.device = None
|
||||
self.dataset = None
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
@ -105,11 +105,13 @@ class BasePredictor:
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
return preds
|
||||
|
||||
def setup(self, source=None, model=None):
|
||||
def setup_source(self, source=None):
|
||||
if not self.model:
|
||||
raise Exception("setup model before setting up source!")
|
||||
# source
|
||||
source, webcam, screenshot, from_img = self.check_source(source)
|
||||
# model
|
||||
stride, pt = self.setup_model(model)
|
||||
stride, pt = self.model.stride, self.model.pt
|
||||
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size
|
||||
|
||||
# Dataloader
|
||||
@ -143,14 +145,12 @@ class BasePredictor:
|
||||
transforms=getattr(self.model.model, 'transforms', None),
|
||||
vid_stride=self.args.vid_stride)
|
||||
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
|
||||
self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz)) # warmup
|
||||
|
||||
self.webcam = webcam
|
||||
self.screenshot = screenshot
|
||||
self.from_img = from_img
|
||||
self.imgsz = imgsz
|
||||
self.done_setup = True
|
||||
return model
|
||||
self.bs = bs
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, source=None, model=None, verbose=False, stream=False):
|
||||
@ -167,8 +167,20 @@ class BasePredictor:
|
||||
|
||||
def stream_inference(self, source=None, model=None, verbose=False):
|
||||
self.run_callbacks("on_predict_start")
|
||||
if not self.done_setup:
|
||||
self.setup(source, model)
|
||||
|
||||
# setup model
|
||||
if not self.model:
|
||||
self.setup_model(model)
|
||||
# setup source. Run every time predict is called
|
||||
self.setup_source(source)
|
||||
# check if save_dir/ label file exists
|
||||
if self.args.save:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
# warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
|
||||
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks("on_predict_batch_start")
|
||||
@ -223,11 +235,9 @@ class BasePredictor:
|
||||
device = select_device(self.args.device)
|
||||
model = model or self.args.model
|
||||
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
|
||||
model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
|
||||
self.model = model
|
||||
self.model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
|
||||
self.device = device
|
||||
self.model.eval()
|
||||
return model.stride, model.pt
|
||||
|
||||
def check_source(self, source):
|
||||
source = source if source is not None else self.args.source
|
||||
|
@ -85,11 +85,11 @@ class Results:
|
||||
|
||||
def __repr__(self):
|
||||
s = f'Ultralytics YOLO {self.__class__} instance\n' # string
|
||||
if self.boxes:
|
||||
if self.boxes is not None:
|
||||
s = s + self.boxes.__repr__() + '\n'
|
||||
if self.masks:
|
||||
if self.masks is not None:
|
||||
s = s + self.masks.__repr__() + '\n'
|
||||
if self.probs:
|
||||
if self.probs is not None:
|
||||
s = s + self.probs.__repr__()
|
||||
s += f'original size: {self.orig_shape}\n'
|
||||
|
||||
|
@ -205,7 +205,7 @@ class BaseTrainer:
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
# Check imgsz
|
||||
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
|
||||
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs * 2)
|
||||
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs)
|
||||
# Batch size
|
||||
if self.batch_size == -1:
|
||||
if RANK == -1: # single-GPU only, estimate best batch size
|
||||
|
@ -372,7 +372,14 @@ def set_sentry(dsn=None):
|
||||
import sentry_sdk # noqa
|
||||
|
||||
import ultralytics
|
||||
sentry_sdk.init(dsn=dsn, traces_sample_rate=1.0, release=ultralytics.__version__, debug=False)
|
||||
sentry_sdk.init(
|
||||
dsn=dsn,
|
||||
debug=False,
|
||||
traces_sample_rate=1.0,
|
||||
release=ultralytics.__version__,
|
||||
send_default_pii=True,
|
||||
environment='production', # 'dev' or 'production'
|
||||
ignore_errors=[KeyboardInterrupt, torch.cuda.OutOfMemoryError])
|
||||
|
||||
|
||||
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory
|
||||
from ultralytics.yolo.utils.plotting import Annotator
|
||||
|
||||
|
||||
@ -67,7 +67,8 @@ class ClassificationPredictor(BasePredictor):
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
predictor = ClassificationPredictor(cfg)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
@ -140,10 +140,13 @@ class ClassificationTrainer(BaseTrainer):
|
||||
def train(cfg):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.lr0 = 0.1
|
||||
cfg.weight_decay = 5e-5
|
||||
cfg.label_smoothing = 0.1
|
||||
cfg.warmup_epochs = 0.0
|
||||
|
||||
# Reproduce ImageNet results
|
||||
# cfg.lr0 = 0.1
|
||||
# cfg.weight_decay = 5e-5
|
||||
# cfg.label_smoothing = 0.1
|
||||
# cfg.warmup_epochs = 0.0
|
||||
|
||||
cfg.device = cfg.device if cfg.device is not None else ''
|
||||
# trainer = ClassificationTrainer(cfg)
|
||||
# trainer.train()
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
|
||||
@ -84,7 +84,8 @@ class DetectionPredictor(BasePredictor):
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
cfg.model = cfg.model or "yolov8n.pt"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
predictor = DetectionPredictor(cfg)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
@ -4,7 +4,7 @@ import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
|
||||
from ultralytics.yolo.utils.plotting import colors, save_one_box
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
|
||||
@ -101,8 +101,8 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
cfg.model = cfg.model or "yolov8n-seg.pt"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
|
||||
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
predictor = SegmentationPredictor(cfg)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
@ -45,6 +45,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
self.jdict = []
|
||||
self.stats = []
|
||||
if self.args.save_json:
|
||||
check_requirements('pycocotools>=2.0.6')
|
||||
self.process = ops.process_mask_upsample # more accurate
|
||||
else:
|
||||
self.process = ops.process_mask # faster
|
||||
@ -189,8 +190,9 @@ class SegmentationValidator(DetectionValidator):
|
||||
self.plot_masks.clear()
|
||||
|
||||
def pred_to_json(self, predn, filename, pred_masks):
|
||||
# Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
|
||||
from pycocotools.mask import encode
|
||||
# Save one JSON result
|
||||
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
|
||||
from pycocotools.mask import encode # noqa
|
||||
|
||||
def single_encode(x):
|
||||
rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
|
||||
|
Loading…
x
Reference in New Issue
Block a user