mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
ultralytics 8.0.45
segment CUDA and DDP callback fixes (#1137)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
bfc078b32f
commit
3765f4f6d9
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
__version__ = '8.0.44'
|
||||
__version__ = '8.0.45'
|
||||
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils.checks import check_yolo as checks
|
||||
|
@ -48,19 +48,21 @@ CLI_HELP_MSG = \
|
||||
"""
|
||||
|
||||
# Define keys for arg type checks
|
||||
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'}
|
||||
CFG_FRACTION_KEYS = {
|
||||
'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'label_smoothing',
|
||||
'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic', 'mixup', 'copy_paste',
|
||||
'conf', 'iou'} # fractional floats limited to 0.0 - 1.0
|
||||
CFG_INT_KEYS = {
|
||||
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
||||
'line_thickness', 'workspace', 'nbs', 'save_period'}
|
||||
CFG_BOOL_KEYS = {
|
||||
'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
|
||||
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
|
||||
'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
|
||||
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
|
||||
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'
|
||||
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
|
||||
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
|
||||
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou') # fractional floats limited to 0.0 - 1.0
|
||||
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
||||
'line_thickness', 'workspace', 'nbs', 'save_period')
|
||||
CFG_BOOL_KEYS = ('save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect',
|
||||
'cos_lr', 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show',
|
||||
'save_txt', 'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment',
|
||||
'agnostic_nms', 'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms',
|
||||
'v5loader')
|
||||
|
||||
# Define valid tasks and modes
|
||||
TASKS = 'detect', 'segment', 'classify'
|
||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
@ -196,9 +198,6 @@ def entrypoint(debug=''):
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
# Define tasks and modes
|
||||
tasks = 'detect', 'segment', 'classify'
|
||||
modes = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
special = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
@ -206,7 +205,7 @@ def entrypoint(debug=''):
|
||||
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
'copy-cfg': copy_default_cfg}
|
||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in tasks}, **{k: None for k in modes}, **special}
|
||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||
|
||||
# Define common mis-uses of special commands, i.e. -h, -help, --help
|
||||
special.update({k[0]: v for k, v in special.items()}) # singular
|
||||
@ -240,9 +239,9 @@ def entrypoint(debug=''):
|
||||
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
||||
check_cfg_mismatch(full_args_dict, {a: ''}, e)
|
||||
|
||||
elif a in tasks:
|
||||
elif a in TASKS:
|
||||
overrides['task'] = a
|
||||
elif a in modes:
|
||||
elif a in MODES:
|
||||
overrides['mode'] = a
|
||||
elif a in special:
|
||||
special[a]()
|
||||
@ -262,10 +261,10 @@ def entrypoint(debug=''):
|
||||
mode = overrides.get('mode', None)
|
||||
if mode is None:
|
||||
mode = DEFAULT_CFG.mode or 'predict'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
||||
elif mode not in modes:
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
||||
elif mode not in MODES:
|
||||
if mode not in ('checks', checks):
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.\n{CLI_HELP_MSG}")
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
||||
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
||||
checks.check_yolo()
|
||||
return
|
||||
@ -280,11 +279,11 @@ def entrypoint(debug=''):
|
||||
model = YOLO(model)
|
||||
|
||||
# Task
|
||||
# if task and task != model.task:
|
||||
# LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
|
||||
# f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
|
||||
overrides['task'] = overrides.get('task', model.task)
|
||||
model.task = overrides['task']
|
||||
task = overrides.get('task', None)
|
||||
if task is not None and task not in TASKS:
|
||||
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
||||
else:
|
||||
model.task = task
|
||||
|
||||
# Mode
|
||||
if mode in {'predict', 'track'} and 'source' not in overrides:
|
||||
|
@ -292,7 +292,10 @@ class Exporter:
|
||||
@try_export
|
||||
def _export_onnx(self, prefix=colorstr('ONNX:')):
|
||||
# YOLOv8 ONNX export
|
||||
check_requirements('onnx>=1.12.0')
|
||||
requirements = ['onnx>=1.12.0']
|
||||
if self.args.simplify:
|
||||
requirements += ['onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
|
||||
check_requirements(requirements)
|
||||
import onnx # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
|
||||
@ -326,7 +329,6 @@ class Exporter:
|
||||
# Simplify
|
||||
if self.args.simplify:
|
||||
try:
|
||||
check_requirements(('onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'))
|
||||
import onnxsim
|
||||
|
||||
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
|
||||
@ -508,9 +510,8 @@ class Exporter:
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
check_requirements(
|
||||
f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if torch.cuda.is_available() else '-cpu'}"
|
||||
)
|
||||
cuda = torch.cuda.is_available()
|
||||
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
|
||||
'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
|
||||
|
@ -64,7 +64,7 @@ class YOLO:
|
||||
Performs prediction using the YOLO model.
|
||||
|
||||
Returns:
|
||||
list[ultralytics.yolo.engine.results.Results]: The prediction results.
|
||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt') -> None:
|
||||
|
@ -111,14 +111,14 @@ class Results:
|
||||
|
||||
Args:
|
||||
show_conf (bool): Whether to show the detection confidence score.
|
||||
line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
|
||||
font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
|
||||
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
|
||||
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
|
||||
font (str): The font to use for the text.
|
||||
pil (bool): Whether to return the image as a PIL Image.
|
||||
example (str): An example string to display in the plot. Useful for indicating the expected format of the output.
|
||||
example (str): An example string to display. Useful for indicating the expected format of the output.
|
||||
|
||||
Returns:
|
||||
None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
|
||||
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
|
||||
"""
|
||||
img = deepcopy(self.orig_img)
|
||||
annotator = Annotator(img, line_width, font_size, font, pil, example)
|
||||
@ -284,7 +284,7 @@ class Masks:
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.
|
||||
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection.
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the masks tensor on CPU memory.
|
||||
|
@ -181,7 +181,7 @@ class BaseTrainer:
|
||||
LOGGER.info(f'Running DDP command {cmd}')
|
||||
subprocess.run(cmd, check=True)
|
||||
except Exception as e:
|
||||
LOGGER.warning(e)
|
||||
raise e
|
||||
finally:
|
||||
ddp_cleanup(self, str(file))
|
||||
else:
|
||||
|
@ -63,7 +63,6 @@ class BaseValidator:
|
||||
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
||||
save_dir (Path): Directory to save results.
|
||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||
logger (logging.Logger): Logger to log messages.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
|
@ -24,8 +24,6 @@ def find_free_network_port() -> int:
|
||||
def generate_ddp_file(trainer):
|
||||
import_path = '.'.join(str(trainer.__class__).split('.')[1:-1])
|
||||
|
||||
if not trainer.resume:
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
|
||||
from ultralytics.{import_path} import {trainer.__class__.__name__}
|
||||
|
||||
@ -43,16 +41,17 @@ def generate_ddp_file(trainer):
|
||||
|
||||
|
||||
def generate_ddp_command(world_size, trainer):
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
|
||||
# Get file and args (do not use sys.argv due to security vulnerability)
|
||||
exclude_args = ['save_dir']
|
||||
args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
|
||||
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
|
||||
|
||||
# Build command
|
||||
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
file = os.path.abspath(sys.argv[0])
|
||||
using_cli = not file.endswith('.py')
|
||||
if not trainer.resume:
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
if using_cli:
|
||||
file = generate_ddp_file(trainer)
|
||||
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
|
||||
port = find_free_network_port()
|
||||
exclude_args = ['save_dir']
|
||||
args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
|
||||
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
|
||||
return cmd, file
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user