ultralytics 8.0.29 - DDP-cls and default arg fixes (#813)

parent 21ae321bc2
commit 7a7c8dc7b7
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = "8.0.28"
+__version__ = "8.0.29"

 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import ops
@@ -262,8 +262,8 @@ def entrypoint(debug=''):
         LOGGER.warning(f"WARNING ⚠️ 'format=' is missing. Using default 'format={overrides['format']}'.")

     # Run command in python
-    cfg = get_cfg(overrides=overrides)
-    getattr(model, mode)(**vars(cfg))
+    # getattr(model, mode)(**vars(get_cfg(overrides=overrides)))  # default args using default.yaml
+    getattr(model, mode)(**overrides)  # default args from model


 # Special modes --------------------------------------------------------------------------------------------------------
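The replacement passes the CLI overrides straight to the chosen model method, so any argument not given on the command line falls back to the defaults stored with the model rather than to default.yaml. The dispatch itself is plain Python; a minimal, self-contained sketch of the pattern (the class and values below are hypothetical stand-ins, not the real YOLO API):

# Sketch of the getattr dispatch used above: the method is looked up by name and the
# overrides dict is unpacked into keyword arguments.
class Model:
    def predict(self, source=None, conf=0.25):
        return f"predict(source={source}, conf={conf})"

model = Model()
mode, overrides = "predict", {"source": "bus.jpg"}
print(getattr(model, mode)(**overrides))  # same as model.predict(source="bus.jpg")
# Keys missing from `overrides` keep the method's own defaults (conf=0.25 here).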
@@ -184,9 +184,6 @@ class Exporter:
             y = model(im)  # dry runs
         if self.args.half and not coreml and not xml:
             im, model = im.half(), model.half()  # to FP16
-        shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
-        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and "
-                    f"output shape {shape} ({file_size(file):.1f} MB)")

         # Warnings
         warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
@@ -207,6 +204,9 @@ class Exporter:
             'stride': int(max(model.stride)),
             'names': model.names}  # model metadata

+        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and "
+                    f"output shape {self.output_shape} ({file_size(file):.1f} MB)")

         # Exports
         f = [''] * len(fmts)  # exported filenames
         if jit:  # TorchScript
@@ -220,9 +220,8 @@ class Exporter:
         if coreml:  # CoreML
             f[4], _ = self._export_coreml()
         if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
-            raise NotImplementedError('YOLOv8 TensorFlow export support is still under development. '
+            LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export support is still under development. '
                            'Please consider contributing to the effort if you have TF expertise. Thank you!')
-            assert not isinstance(model, ClassificationModel), 'ClassificationModel TF exports not yet supported.'
             nms = False
             f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
                                                      agnostic_nms=self.args.agnostic_nms or tfjs)
@@ -236,7 +235,7 @@ class Exporter:
                                              agnostic_nms=self.args.agnostic_nms)
             if edgetpu:
                 f[8], _ = self._export_edgetpu()
-            self._add_tflite_metadata(f[8] or f[7], num_outputs=len(s_model.outputs))
+            self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
             if tfjs:
                 f[9], _ = self._export_tfjs()
         if paddle:  # PaddlePaddle
@@ -552,13 +551,13 @@ class Exporter:
         return f, keras_model

     @try_export
-    def _export_pb(self, keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
+    def _export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
         # YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
         import tensorflow as tf  # noqa
         from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

         LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
-        f = file.with_suffix('.pb')
+        f = self.file.with_suffix('.pb')

         m = tf.function(lambda x: keras_model(x))  # full model
         m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
@@ -119,7 +119,6 @@ class YOLO:
     def fuse(self):
         self.model.fuse()

-    @smart_inference_mode()
     def predict(self, source=None, stream=False, **kwargs):
         """
         Perform prediction using the YOLO model.
@@ -258,8 +257,6 @@ class YOLO:

     @staticmethod
     def _reset_ckpt_args(args):
-        for arg in 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', 'save_json', \
-                'half', 'v5loader':
+        for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
+                'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots':
             args.pop(arg, None)

-        args["device"] = ''  # set device to '' to prevent auto-DDP usage
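The widened pop list now also strips 'augment', 'device', 'cfg', 'save', 'rect' and 'plots' from the arguments restored from a checkpoint, replacing the old trick of forcing args["device"] = '' and keeping a reloaded multi-GPU training run from triggering an unwanted auto-DDP launch at predict or val time. A small, hypothetical illustration of the pattern (not the real checkpoint schema):

# Illustrative sketch: popping run-specific keys from a saved-args dict so they do
# not leak into later predict/val calls.
ckpt_args = {"imgsz": 640, "device": "0,1", "epochs": 100, "batch": 16}
for arg in ("device", "epochs", "batch"):  # subset of the keys removed above
    ckpt_args.pop(arg, None)               # pop() with a default never raises KeyError
print(ckpt_args)                           # {'imgsz': 640}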
@@ -457,7 +457,7 @@ class BaseTrainer:
     def get_validator(self):
         raise NotImplementedError("get_validator function not implemented in trainer")

-    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
         """
         Returns dataloader derived from torch.data.Dataloader.
         """
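The new mode argument lets a trainer subclass build both its training and validation loaders from the same method. A minimal, self-contained sketch of how such a flag can drive loader construction (hypothetical stand-in, not the Ultralytics implementation):

# Stand-in dataset and loader; `mode` toggles training-only behaviour such as shuffling.
import torch
from torch.utils.data import DataLoader, TensorDataset

def get_dataloader(dataset_path, batch_size=16, rank=0, mode="train"):
    data = TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,)))  # stand-in dataset
    return DataLoader(data, batch_size=batch_size, shuffle=(mode == "train"))  # shuffle only for training

train_loader = get_dataloader("unused_path", mode="train")
val_loader = get_dataloader("unused_path", mode="val")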
@@ -485,18 +485,20 @@ def set_sentry():

     if SETTINGS['sync'] and \
             RANK in {-1, 0} and \
+            sys.argv[0].endswith('yolo') and \
             not is_pytest_running() and \
             not is_github_actions_ci() and \
             ((is_pip_package() and not is_git_dir()) or
              (get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
-        import sentry_sdk  # noqa

-        import ultralytics
+        import sentry_sdk  # noqa
+        from ultralytics import __version__

         sentry_sdk.init(
-            dsn="https://1f331c322109416595df20a91f4005d3@o4504521589325824.ingest.sentry.io/4504521592406016",
+            dsn="https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016",
             debug=False,
             traces_sample_rate=1.0,
-            release=ultralytics.__version__,
+            release=__version__,
             environment='production',  # 'dev' or 'production'
             before_send=before_send,
             ignore_errors=[KeyboardInterrupt, FileNotFoundError])
@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
+import contextlib
 import glob
 import inspect
 import math
@@ -7,9 +7,9 @@ import os
 import platform
 import re
 import shutil
+import subprocess
 import urllib
 from pathlib import Path
-from subprocess import check_output
 from typing import Optional

 import cv2
@@ -155,11 +155,10 @@ def check_online() -> bool:
         bool: True if connection is successful, False otherwise.
     """
     import socket
-    try:
-        # Check host accessibility by attempting to establish a connection
-        socket.create_connection(("1.1.1.1", 443), timeout=5)
-        return True
-    except OSError:
-        return False
+    with contextlib.suppress(subprocess.CalledProcessError):
+        host = socket.gethostbyname("www.github.com")
+        socket.create_connection((host, 80), timeout=2)
+        return True
+    return False

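The rewrite resolves www.github.com via DNS and opens a short-lived connection on port 80 instead of hard-coding 1.1.1.1:443, and swaps try/except for contextlib.suppress. A sketch of the suppress pattern in isolation; it suppresses OSError (the class socket calls actually raise) purely for illustration, whereas the diff lists subprocess.CalledProcessError:

# contextlib.suppress(X) behaves like try/except X: pass around the indented block.
import contextlib
import socket

def check_online() -> bool:
    with contextlib.suppress(OSError):  # OSError chosen here for illustration only
        host = socket.gethostbyname("www.github.com")
        socket.create_connection((host, 80), timeout=2).close()
        return True
    return False  # reached when the suppressed exception aborted the block

print(check_online())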
@@ -181,6 +180,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages or single package str)
     prefix = colorstr('red', 'bold', 'requirements:')
     check_python()  # check python version
+    file = None
     if isinstance(requirements, Path):  # requirements.txt file
         file = requirements.resolve()
         assert file.exists(), f"{prefix} {file} not found, check failed."
@@ -202,9 +202,8 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
         LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
         try:
             assert check_online(), "AutoUpdate skipped (offline)"
-            LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
-            source = file if 'file' in locals() else requirements
-            s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
+            LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
+            s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
                 f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
             LOGGER.info(s)
         except Exception as e:
@@ -306,7 +305,7 @@ def git_describe(path=ROOT):  # path must be a directory
     # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
     try:
         assert (Path(path) / '.git').is_dir()
-        return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
+        return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
     except AssertionError:
         return ''

@@ -246,7 +246,7 @@ def intersect_dicts(da, db, exclude=()):

 def is_parallel(model):
     # Returns True if model is of type DP or DDP
-    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


 def de_parallel(model):
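isinstance() also matches subclasses of the DP/DDP wrappers, whereas type() requires an exact type match. A minimal sketch of the difference (MyDP is a hypothetical subclass):

# A subclass of DataParallel fails the exact type() check but is still a parallel wrapper.
import torch.nn as nn

class MyDP(nn.parallel.DataParallel):  # hypothetical subclass
    pass

m = MyDP(nn.Linear(4, 2))
print(type(m) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))      # False
print(isinstance(m, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)))  # True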
@@ -1,5 +1,4 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
-import sys

 import torch
 import torchvision
@@ -9,7 +8,7 @@ from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.trainer import BaseTrainer
 from ultralytics.yolo.utils import DEFAULT_CFG
-from ultralytics.yolo.utils.torch_utils import strip_optimizer
+from ultralytics.yolo.utils.torch_utils import strip_optimizer, is_parallel


 class ClassificationTrainer(BaseTrainer):
@@ -56,7 +55,7 @@ class ClassificationTrainer(BaseTrainer):
         # Load a YOLO model locally, from torchvision, or from Ultralytics assets
         if model.endswith(".pt"):
             self.model, _ = attempt_load_one_weight(model, device='cpu')
-            for p in model.parameters():
+            for p in self.model.parameters():
                 p.requires_grad = True  # for training
         elif model.endswith(".yaml"):
             self.model = self.get_model(cfg=model)
@@ -75,8 +74,12 @@ class ClassificationTrainer(BaseTrainer):
                                                  augment=mode == "train",
                                                  rank=rank,
                                                  workers=self.args.workers)
+        # Attach inference transforms
         if mode != "train":
-            self.model.transforms = loader.dataset.torch_transforms  # attach inference transforms
+            if is_parallel(self.model):
+                self.model.module.transforms = loader.dataset.torch_transforms
+            else:
+                self.model.transforms = loader.dataset.torch_transforms
         return loader

     def preprocess_batch(self, batch):
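This is the "DDP-cls" fix from the commit title: under DDP the classification trainer's model is wrapped, so the inference transforms must be attached to model.module or they never reach the underlying network that is later unwrapped and saved. A minimal sketch of why the branch is needed, using nn.DataParallel because it wraps a module the same way as DDP without requiring a process group:

# DP/DDP wrap the original model; attributes set on the wrapper stay on the wrapper.
import torch.nn as nn

wrapped = nn.DataParallel(nn.Linear(4, 2))
wrapped.module.transforms = "inference transforms"  # reaches the real model
print(hasattr(wrapped.module, "transforms"))        # True

plain = nn.Linear(4, 2)                             # unwrapped (single-GPU / CPU) case
plain.transforms = "inference transforms"
print(hasattr(plain, "transforms"))                 # True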
|
Loading…
x
Reference in New Issue
Block a user