mirror of https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00

Fix `yolo checks` as a package bug in Colab (#972)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sergio Sanchez <sergio.ssm.97@gmail.com>

parent bdc6cd4d8b
commit 1ad7e79033
@@ -62,7 +62,7 @@ full documentation on training, validation, prediction and deployment.

 Pip install the ultralytics package including
 all [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) in a
-[**3.10>=Python>=3.7**](https://www.python.org/) environment, including
+[**Python>=3.7**](https://www.python.org/) environment with
 [**PyTorch>=1.7**](https://pytorch.org/get-started/locally/).

 ```bash
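Note: the new wording keeps only lower bounds (the Python 3.10 ceiling is dropped). A minimal sketch, not part of this diff, for checking those floors before `pip install ultralytics`:

```python
# Hypothetical pre-install check for the documented floors (Python>=3.7, PyTorch>=1.7).
import sys

import torch

assert sys.version_info >= (3, 7), f'Python>=3.7 required, found {sys.version.split()[0]}'
torch_major_minor = tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2])
assert torch_major_minor >= (1, 7), f'PyTorch>=1.7 required, found {torch.__version__}'
```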
@@ -53,7 +53,7 @@
 <details open>
 <summary>安装</summary>

-Pip 安装包含所有 [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) 的 ultralytics 包,环境要求 [**3.10>=Python>=3.7**](https://www.python.org/),且 [**PyTorch>=1.7**](https://pytorch.org/get-started/locally/)。
+Pip 安装包含所有 [requirements.txt](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) 的 ultralytics 包,环境要求 [**Python>=3.7**](https://www.python.org/),且 [**PyTorch>=1.7**](https://pytorch.org/get-started/locally/)。

 ```bash
 pip install ultralytics
@@ -30,6 +30,7 @@ seaborn>=0.11.0
 # nvidia-tensorrt  # TensorRT export
 # scikit-learn==0.19.2  # CoreML quantization
 # tensorflow>=2.4.1  # TF exports (-cpu, -aarch64, -macos)
+# tflite-support
 # tensorflowjs>=3.9.0  # TF.js export
 # openvino-dev>=2022.3  # OpenVINO export

@@ -1,5 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

+import platform
 from pathlib import Path

 import cv2
@@ -14,6 +15,7 @@ from ultralytics.yolo.utils import ROOT, SETTINGS
 MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
 CFG = 'yolov8n.yaml'
 SOURCE = ROOT / 'assets/bus.jpg'
+MACOS = platform.system() == 'Darwin'  # macOS environment


 def test_model_forward():
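Note: `MACOS` keys on `platform.system()`, which returns 'Darwin' on macOS ('Linux'/'Windows' elsewhere); a quick standalone check:

```python
import platform

print(platform.system())              # 'Darwin' on macOS, 'Linux' or 'Windows' elsewhere
print(platform.system() == 'Darwin')  # value the MACOS flag would take on this machine
```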
@@ -120,10 +122,11 @@ def test_export_openvino():
     YOLO(f)(SOURCE)  # exported model inference


-def test_export_coreml():
+def test_export_coreml():  # sourcery skip: move-assign
     model = YOLO(MODEL)
-    model.export(format='coreml')
-    # YOLO(f)(SOURCE)  # model prediction only supported on macOS
+    f = model.export(format='coreml')
+    if MACOS:
+        YOLO(f)(SOURCE)  # model prediction only supported on macOS


 def test_export_paddle(enabled=False):
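Note: the CoreML test now runs exported-model inference only when the new MACOS flag is true. The same gating could also be expressed with pytest's skip marker; a sketch with a hypothetical test name, not part of this diff:

```python
import platform

import pytest

MACOS = platform.system() == 'Darwin'


@pytest.mark.skipif(not MACOS, reason='CoreML model prediction is only supported on macOS')
def test_coreml_inference_macos_only():
    ...  # export with format='coreml' and run YOLO(f)(SOURCE) here
```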
@@ -12,7 +12,7 @@ import torch.nn as nn
 from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
                                     Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
                                     GhostBottleneck, GhostConv, Segment)
-from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
+from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, colorstr, yaml_load
 from ultralytics.yolo.utils.checks import check_requirements, check_yaml
 from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
                                                 intersect_dicts, make_divisible, model_info, scale_img, time_sync)
@@ -239,7 +239,7 @@ class DetectionModel(BaseModel):
         csd = weights.float().state_dict()  # checkpoint state_dict as FP32
         csd = intersect_dicts(csd, self.state_dict())  # intersect
         self.load_state_dict(csd, strict=False)  # load
-        if verbose:
+        if verbose and RANK == -1:
             LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')

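Note: `RANK` follows the torch.distributed launcher convention, so `verbose and RANK == -1` keeps the transfer message out of DDP worker logs. An illustrative sketch of that convention (assumed here, not shown in this diff):

```python
import os

RANK = int(os.getenv('RANK', -1))  # -1 when not launched via torchrun / DDP

if RANK == -1:
    print('single-process run: log the "Transferred x/y items" message')
else:
    print(f'DDP process rank {RANK}: suppress duplicate model-loading logs')
```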
@@ -216,6 +216,9 @@ def entrypoint(debug=''):

     overrides = {}  # basic overrides, i.e. imgsz=320
     for a in merge_equals_args(args):  # merge spaces around '=' sign
+        if a.startswith('--'):
+            LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
+            a = a[2:]
         if '=' in a:
             try:
                 re.sub(r' *= *', '=', a)  # remove spaces around equals sign
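Note: the new branch lets users type `--imgsz=320`-style flags; they are warned and normalized. A standalone sketch of the same normalization:

```python
args = ['--imgsz=320', 'model=yolov8n.pt', 'checks']
normalized = []
for a in args:
    if a.startswith('--'):
        print(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
        a = a[2:]
    normalized.append(a)
print(normalized)  # ['imgsz=320', 'model=yolov8n.pt', 'checks']
```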
@@ -263,7 +266,7 @@ def entrypoint(debug=''):
        mode = DEFAULT_CFG.mode or 'predict'
        LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
    elif mode not in modes:
-        if mode != 'checks':
+        if mode not in ('checks', checks):
            raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.\n{CLI_HELP_MSG}")
        LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
        checks.check_yolo()
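Note: `yolo mode=checks` is kept only as a deprecated alias for the new `yolo checks` command; both dispatch to the same helper called above. The programmatic equivalent:

```python
from ultralytics.yolo.utils import checks

checks.check_yolo()  # prints the environment summary behind `yolo checks`
```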
@@ -206,7 +206,7 @@ class Exporter:
         self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
         self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
         self.metadata = {
-            'description': f"Ultralytics {self.pretty_name} model trained on {self.model.args['data']}",
+            'description': f"Ultralytics {self.pretty_name} model trained on {self.args.data}",
             'author': 'Ultralytics',
             'license': 'GPL-3.0 https://ultralytics.com/license',
             'version': __version__,
@@ -257,11 +257,16 @@ class Exporter:
         f = [str(x) for x in f if x]  # filter out '' and None
         if any(f):
             f = str(Path(f[-1]))
-            LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
-                        f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
-                        f"\nPredict: yolo task={model.task} mode=predict model={f}"
-                        f"\nValidate: yolo task={model.task} mode=val model={f}"
-                        f"\nVisualize: https://netron.app")
+            square = self.imgsz[0] == self.imgsz[1]
+            s = f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not work. Use " \
+                f"export 'imgsz={max(self.imgsz)}' if val is required." if not square else ''
+            imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
+            LOGGER.info(
+                f'\nExport complete ({time.time() - t:.1f}s)'
+                f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+                f"\nPredict: yolo task={model.task} mode=predict model={f} imgsz={imgsz}"
+                f"\nValidate: yolo task={model.task} mode=val model={f} imgsz={imgsz} data={self.args.data} {s}"
+                f"\nVisualize: https://netron.app")

         self.run_callbacks("on_export_end")
         return f  # return list of exported files/dirs
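Note: the summary now prints an `imgsz` that can be pasted back into the CLI: a single int for square exports, 'h,w' plus a val warning otherwise. A standalone sketch of that logic:

```python
for imgsz in ([640, 640], [640, 480]):
    square = imgsz[0] == imgsz[1]
    s = (f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={imgsz}' will not work. "
         f"Use export 'imgsz={max(imgsz)}' if val is required.") if not square else ''
    shown = imgsz[0] if square else str(imgsz)[1:-1].replace(' ', '')
    print(shown, s)
# 640
# 640,480 WARNING ⚠️ non-PyTorch val requires square images, ...
```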
@@ -497,7 +502,7 @@ class Exporter:
         except ImportError:
             check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
             import tensorflow as tf  # noqa
-        check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon"),
+        check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon", "tflite_support"),
                           cmds="--extra-index-url https://pypi.ngc.nvidia.com ")

         LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
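Note: adding 'tflite_support' here means the TF/TFLite export path installs it up front instead of `_add_tflite_metadata` doing so later. A simplified sketch of what a `check_requirements`-style helper does (an assumption, not the ultralytics implementation):

```python
import importlib.util
import subprocess
import sys


def ensure_installed(packages, cmds=''):
    # Pip-install any package whose module cannot be found (simplified; the real helper also checks versions).
    missing = [p for p in packages if importlib.util.find_spec(p.replace('-', '_')) is None]
    if missing:
        subprocess.run([sys.executable, '-m', 'pip', 'install', *cmds.split(), *missing], check=True)


ensure_installed(('onnx', 'tflite_support'), cmds='--extra-index-url https://pypi.ngc.nvidia.com')
```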
@@ -680,24 +685,45 @@ class Exporter:

     def _add_tflite_metadata(self, file):
         # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
-        check_requirements('tflite_support')
-
         from tflite_support import flatbuffers  # noqa
         from tflite_support import metadata as _metadata  # noqa
         from tflite_support import metadata_schema_py_generated as _metadata_fb  # noqa

+        # Creates model info.
+        model_meta = _metadata_fb.ModelMetadataT()
+        model_meta.name = self.metadata['description']
+        model_meta.version = self.metadata['version']
+        model_meta.author = self.metadata['author']
+        model_meta.license = self.metadata['license']
+
+        # Creates input info.
+        input_meta = _metadata_fb.TensorMetadataT()
+        input_meta.name = "image"
+        input_meta.description = "Input image to be detected."
+        input_meta.content = _metadata_fb.ContentT()
+        input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
+        input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
+        input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties
+
+        # Creates output info.
+        output_meta = _metadata_fb.TensorMetadataT()
+        output_meta.name = "output"
+        output_meta.description = "Coordinates of detected objects, class labels, and confidence score."
+
+        # Label file
         tmp_file = Path('/tmp/meta.txt')
         with open(tmp_file, 'w') as meta_f:
             meta_f.write(str(self.metadata))

-        model_meta = _metadata_fb.ModelMetadataT()
         label_file = _metadata_fb.AssociatedFileT()
         label_file.name = tmp_file.name
-        model_meta.associatedFiles = [label_file]
+        label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
+        output_meta.associatedFiles = [label_file]

+        # Creates subgraph info.
         subgraph = _metadata_fb.SubGraphMetadataT()
-        subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
-        subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * len(self.output_shape)
+        subgraph.inputTensorMetadata = [input_meta]
+        subgraph.outputTensorMetadata = [output_meta]
         model_meta.subgraphMetadata = [subgraph]

         b = flatbuffers.Builder(0)
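Note: with the metadata now fully populated (model, input, output, and the meta.txt associated file), it can be read back after export with tflite_support; a hedged usage sketch (the .tflite path is an assumption):

```python
from tflite_support import metadata as _metadata

displayer = _metadata.MetadataDisplayer.with_model_file('yolov8n_float32.tflite')  # path is an assumption
print(displayer.get_metadata_json())                # description/author/license plus input/output tensor info
print(displayer.get_packed_associated_file_list())  # should list 'meta.txt'
```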
@@ -710,6 +736,14 @@ class Exporter:
         populator.populate()
         tmp_file.unlink()

+    # TODO Rename this here and in `_add_tflite_metadata`
+    def _extracted_from__add_tflite_metadata_15(self, _metadata_fb, arg1, arg2):
+        # Creates input info.
+        result = _metadata_fb.TensorMetadataT()
+        result.name = arg1
+        result.description = arg2
+        return result
+
     def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
         # YOLOv8 CoreML pipeline
         import coremltools as ct  # noqa
@@ -81,7 +81,7 @@ class YOLO:
         cfg_dict = yaml_load(self.cfg, append_filename=True)  # model dict
         self.task = guess_model_task(cfg_dict)
         self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
-        self.model = self.ModelClass(cfg_dict, verbose=verbose)  # initialize
+        self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1)  # initialize

     def _load(self, weights: str):
         """
@@ -240,7 +240,7 @@ class YOLO:
         if RANK in {0, -1}:
             self.model, _ = attempt_load_one_weight(str(self.trainer.best))
             self.overrides = self.model.args
-            self.metrics_data = self.trainer.validator.metrics
+            self.metrics_data = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP

     def to(self, device):
         """
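Note: under DDP the validator may not expose `metrics`, so the plain attribute access raised; `getattr(..., None)` degrades to None instead. A tiny reproduction of the difference:

```python
class FakeValidator:
    pass  # no `metrics` attribute, as in the DDP case noted in the TODO above


validator = FakeValidator()
print(getattr(validator, 'metrics', None))  # None
# print(validator.metrics)  # would raise AttributeError, the failure this change avoids
```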
@@ -85,6 +85,7 @@ class BaseTrainer:
         self.console = LOGGER
         self.validator = None
         self.model = None
+        self.metrics = None
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

         # Dirs
@@ -417,7 +418,7 @@ class BaseTrainer:
             cfg = ckpt["model"].yaml
         else:
             cfg = model
-        self.model = self.get_model(cfg=cfg, weights=weights)  # calls Model(cfg, weights)
+        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
         return ckpt

     def optimizer_step(self):
@@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.trainer import BaseTrainer
-from ultralytics.yolo.utils import DEFAULT_CFG
+from ultralytics.yolo.utils import DEFAULT_CFG, RANK
 from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer


@@ -23,7 +23,7 @@ class ClassificationTrainer(BaseTrainer):
         self.model.names = self.data["names"]

     def get_model(self, cfg=None, weights=None, verbose=True):
-        model = ClassificationModel(cfg, nc=self.data["nc"])
+        model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)

@@ -9,7 +9,7 @@ from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_dataloader
 from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
 from ultralytics.yolo.engine.trainer import BaseTrainer
-from ultralytics.yolo.utils import DEFAULT_CFG, colorstr
+from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
 from ultralytics.yolo.utils.loss import BboxLoss
 from ultralytics.yolo.utils.ops import xywh2xyxy
 from ultralytics.yolo.utils.plotting import plot_images, plot_results
@@ -57,7 +57,7 @@ class DetectionTrainer(BaseTrainer):
         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc

     def get_model(self, cfg=None, weights=None, verbose=True):
-        model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
+        model = DetectionModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)

@@ -6,7 +6,7 @@ import torch.nn.functional as F

 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.yolo import v8
-from ultralytics.yolo.utils import DEFAULT_CFG
+from ultralytics.yolo.utils import DEFAULT_CFG, RANK
 from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
 from ultralytics.yolo.utils.plotting import plot_images, plot_results
 from ultralytics.yolo.utils.tal import make_anchors
@@ -24,7 +24,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
         super().__init__(cfg, overrides)

     def get_model(self, cfg=None, weights=None, verbose=True):
-        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
+        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
