mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Revert 1783 fix callbacks by reference (#1847)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
2fcca6a0ea
commit
e7a94c79c5
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = '8.0.63'
|
__version__ = '8.0.64'
|
||||||
|
|
||||||
from ultralytics.hub import start
|
from ultralytics.hub import start
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -78,7 +77,7 @@ class YOLO:
|
|||||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.callbacks = deepcopy(callbacks.default_callbacks)
|
self._reset_callbacks()
|
||||||
self.predictor = None # reuse predictor
|
self.predictor = None # reuse predictor
|
||||||
self.model = None # model object
|
self.model = None # model object
|
||||||
self.trainer = None # trainer object
|
self.trainer = None # trainer object
|
||||||
@ -118,7 +117,7 @@ class YOLO:
|
|||||||
return any((
|
return any((
|
||||||
model.startswith('https://hub.ultralytics.com/models/'),
|
model.startswith('https://hub.ultralytics.com/models/'),
|
||||||
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
||||||
(len(model) == 20 and not Path(model).exists() and not any(x in model for x in './\\')))) # MODELID
|
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
|
||||||
|
|
||||||
def _new(self, cfg: str, task=None, verbose=True):
|
def _new(self, cfg: str, task=None, verbose=True):
|
||||||
"""
|
"""
|
||||||
@ -228,8 +227,8 @@ class YOLO:
|
|||||||
if source is None:
|
if source is None:
|
||||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||||
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and (
|
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
|
||||||
('predict' in sys.argv or 'mode=predict' in sys.argv) or ('track' in sys.argv or 'mode=track' in sys.argv))
|
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides['conf'] = 0.25
|
overrides['conf'] = 0.25
|
||||||
overrides.update(kwargs) # prefer kwargs
|
overrides.update(kwargs) # prefer kwargs
|
||||||
@ -238,7 +237,7 @@ class YOLO:
|
|||||||
overrides['save'] = kwargs.get('save', False) # not save files by default
|
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||||
if not self.predictor:
|
if not self.predictor:
|
||||||
self.task = overrides.get('task') or self.task
|
self.task = overrides.get('task') or self.task
|
||||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
|
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
|
||||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||||
else: # only update args if predictor is already setup
|
else: # only update args if predictor is already setup
|
||||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||||
@ -387,17 +386,19 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||||
|
|
||||||
def add_callback(self, event: str, func):
|
@staticmethod
|
||||||
|
def add_callback(event: str, func):
|
||||||
"""
|
"""
|
||||||
Add callback
|
Add callback
|
||||||
"""
|
"""
|
||||||
self.callbacks[event].append(func)
|
callbacks.default_callbacks[event].append(func)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_ckpt_args(args):
|
def _reset_ckpt_args(args):
|
||||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||||
return {k: v for k, v in args.items() if k in include}
|
return {k: v for k, v in args.items() if k in include}
|
||||||
|
|
||||||
def _reset_callbacks(self):
|
@staticmethod
|
||||||
|
def _reset_callbacks():
|
||||||
for event in callbacks.default_callbacks.keys():
|
for event in callbacks.default_callbacks.keys():
|
||||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||||
|
@ -75,7 +75,7 @@ class BasePredictor:
|
|||||||
data_path (str): Path to data.
|
data_path (str): Path to data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||||
"""
|
"""
|
||||||
Initializes the BasePredictor class.
|
Initializes the BasePredictor class.
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ class BasePredictor:
|
|||||||
self.data_path = None
|
self.data_path = None
|
||||||
self.source_type = None
|
self.source_type = None
|
||||||
self.batch = None
|
self.batch = None
|
||||||
self.callbacks = defaultdict(list, _callbacks) if _callbacks else defaultdict(list, callbacks.default_callbacks)
|
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
def preprocess(self, img):
|
def preprocess(self, img):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user