mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 21:44:22 +08:00)

General console printout updates (#48)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 8530e3fae0 · commit 27d6545117
@@ -665,7 +665,7 @@ def mosaic_transforms(img_size, hyp):
             perspective=hyp.perspective,
             border=[-img_size // 2, -img_size // 2],
         ),])
-    transforms = Compose([
+    return Compose([
         pre_transform,
         MixUp(
             pre_transform=pre_transform,
@@ -674,13 +674,11 @@ def mosaic_transforms(img_size, hyp):
         Albumentations(p=1.0),
         RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
         RandomFlip(direction="vertical", p=hyp.flipud),
-        RandomFlip(direction="horizontal", p=hyp.fliplr),])
-    return transforms
+        RandomFlip(direction="horizontal", p=hyp.fliplr),])  # transforms


 def affine_transforms(img_size, hyp):
-    # rect, randomperspective, albumentation, hsv, flipud, fliplr
-    transforms = Compose([
+    return Compose([
         LetterBox(new_shape=(img_size, img_size)),
         RandomPerspective(
             degrees=hyp.degrees,
@@ -693,11 +691,10 @@ def affine_transforms(img_size, hyp):
         Albumentations(p=1.0),
         RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
         RandomFlip(direction="vertical", p=hyp.flipud),
-        RandomFlip(direction="horizontal", p=hyp.fliplr),])
-    return transforms
+        RandomFlip(direction="horizontal", p=hyp.fliplr),])  # transforms


-# Classification augmentations -------------------------------------------------------------------------------------------
+# Classification augmentations -----------------------------------------------------------------------------------------
 def classify_transforms(size=224):
     # Transforms to apply if albumentations not installed
     assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
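For context, a hedged sketch of how these factory functions are consumed now that they return the Compose directly (hyp is any config namespace carrying the referenced fields; the dict-in/dict-out call convention is how this pipeline's transforms compose):

transforms = mosaic_transforms(img_size=640, hyp=hyp)  # a Compose; no intermediate variable to return
labels = transforms(labels)  # each transform is applied to the label dict in order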
@@ -9,8 +9,8 @@ import numpy as np
 from torch.utils.data import Dataset
 from tqdm import tqdm

-from ..utils import NUM_THREADS
-from .utils import BAR_FORMAT, HELP_URL, IMG_FORMATS, LOCAL_RANK
+from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
+from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK


 class BaseDataset(Dataset):
@@ -18,7 +18,7 @@ class BaseDataset(Dataset):
     Args:
         img_path (str): image path.
         pipeline (dict): a dict of image transforms.
-        label_path (str): label path, this can also be a ann_file or other custom label path.
+        label_path (str): label path, this can also be an ann_file or other custom label path.
     """

     def __init__(
@@ -131,7 +131,7 @@ class BaseDataset(Dataset):
         self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
         fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image
         results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
-        pbar = tqdm(enumerate(results), total=self.ni, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
+        pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
         for i, x in pbar:
             if self.cache == "disk":
                 gb += self.npy_files[i].stat().st_size
@@ -6,10 +6,10 @@ from typing import OrderedDict
 import torchvision
 from tqdm import tqdm

-from ..utils import NUM_THREADS
+from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
 from .augment import *
 from .base import BaseDataset
-from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
+from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label


 class YOLODataset(BaseDataset):
@@ -40,7 +40,7 @@ class YOLODataset(BaseDataset):
     ):
         self.use_segments = use_segments
         self.use_keypoints = use_keypoints
-        assert not (self.use_segments and self.use_keypoints), "We can't use both of segmentation and pose."
+        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
         super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
                          single_cls)

@@ -48,14 +48,14 @@ class YOLODataset(BaseDataset):
         # Cache dataset labels, check images and read shapes
         x = {"labels": []}
         nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
-        desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
+        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
         with Pool(NUM_THREADS) as pool:
             pbar = tqdm(
                 pool.imap(verify_image_label,
                           zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
                 desc=desc,
                 total=len(self.im_files),
-                bar_format=BAR_FORMAT,
+                bar_format=TQDM_BAR_FORMAT,
             )
             for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                 nm += nm_f
@@ -76,7 +76,7 @@ class YOLODataset(BaseDataset):
                     ))
                 if msg:
                     msgs.append(msg)
-                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
+                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"

             pbar.close()
         if msgs:
@@ -109,8 +109,8 @@ class YOLODataset(BaseDataset):
         # Display cache
         nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
         if exists and LOCAL_RANK in {-1, 0}:
-            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
-            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT)  # display cache results
+            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
+            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
             if cache["msgs"]:
                 LOGGER.info("\n".join(cache["msgs"]))  # display warnings
         assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
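For reference, a hedged sketch of how the reworded scan line renders under the new bar format (path and counts here are invented, and "train: " stands in for self.prefix):

train: Scanning /datasets/coco128/labels/train2017... 126 images, 2 backgrounds, 0 corrupt: 100%|██████████| 128/128 00:01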
@@ -22,7 +22,6 @@ from ..utils.ops import segments2boxes
 HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
 IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes
 VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv"  # include video suffixes
-BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}"  # tqdm bar format
 LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
 PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
@@ -25,8 +25,8 @@ from tqdm import tqdm
 import ultralytics.yolo.utils as utils
 import ultralytics.yolo.utils.loggers as loggers
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
-from ultralytics.yolo.utils import LOGGER, ROOT
-from ultralytics.yolo.utils.checks import check_file, check_yaml
+from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
+from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model

@@ -41,19 +41,17 @@ class BaseTrainer:
         self.validator = None
         self.model = None
         self.callbacks = defaultdict(list)
-        self.console.info(f"Training config: \n args: \n {self.args}")  # to debug
-        # Directories
         self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
-        self.wdir = self.save_dir / 'weights'
+        self.wdir = self.save_dir / 'weights'  # weights dir
         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
-        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
+        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
+        print_args(dict(self.args))

         # Save run settings
         save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))

         # device
         self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
-        self.console.info(f"running on device {self.device}")
         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')

         # Model and Dataloaders.
@@ -64,7 +62,7 @@ class BaseTrainer:
         self.data = check_dataset(self.data)
         self.trainset, self.testset = self.get_dataset(self.data)
         if self.args.model:
-            self.model = self.get_model(self.args.model, self.data)
+            self.model = self.get_model(self.args.model)

         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -115,7 +113,7 @@ class BaseTrainer:
         if world_size > 1:
             mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
         else:
-            self._do_train(-1, 1)
+            self._do_train()

     def _setup_ddp(self, rank, world_size):
         os.environ['MASTER_ADDR'] = 'localhost'
@@ -147,7 +145,7 @@ class BaseTrainer:
         print("created testloader :", rank)
         self.console.info(self.progress_string())

-    def _do_train(self, rank, world_size):
+    def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
             self._setup_ddp(rank, world_size)
         else:
@@ -165,9 +163,7 @@ class BaseTrainer:
             self.model.train()
             pbar = enumerate(self.train_loader)
             if rank in {-1, 0}:
-                pbar = tqdm(enumerate(self.train_loader),
-                            total=len(self.train_loader),
-                            bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
+                pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
             tloss = None
             for i, batch in pbar:
                 # img, label (classification)/ img, targets, paths, _, masks(detection)
@@ -249,18 +245,14 @@ class BaseTrainer:
         """
         return data["train"], data["val"]

-    def get_model(self, model: str, data: Dict):
+    def get_model(self, model: Union[str, Path]):
         """
         load/create/download model for any task
         """
-        pretrained = False
-        if not str(model).endswith(".yaml"):
-            pretrained = True
-            weights = get_model(model)  # rename this to something less confusing?
-        model = self.load_model(model_cfg=model if not pretrained else None,
-                                weights=weights if pretrained else None,
-                                data=self.data)
-        return model
+        pretrained = not str(model).endswith(".yaml")
+        return self.load_model(model_cfg=None if pretrained else model,
+                               weights=get_model(model) if pretrained else None,
+                               data=self.data)  # model

     def load_model(self, model_cfg, weights, data):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
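The get_model() rewrite above folds the four-line flag dance into a single expression: a ".yaml" argument means "build from config", anything else is treated as pretrained weights. A standalone sketch of the same decision logic (get_model_kind and its return shape are hypothetical, for illustration only):

from pathlib import Path
from typing import Union

def get_model_kind(model: Union[str, Path]):
    # Mirrors the refactored branch: config path vs. pretrained weights
    pretrained = not str(model).endswith(".yaml")
    return ("weights" if pretrained else "config", str(model))

assert get_model_kind("yolov5n-seg.yaml") == ("config", "yolov5n-seg.yaml")
assert get_model_kind("yolov5n-seg.pt") == ("weights", "yolov5n-seg.pt")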
@@ -5,6 +5,7 @@ from omegaconf import OmegaConf
 from tqdm import tqdm

 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
+from ultralytics.yolo.utils import TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.ops import Profile
 from ultralytics.yolo.utils.torch_utils import de_parallel, select_device

@@ -49,7 +50,7 @@ class BaseValidator:
         loss = 0
         n_batches = len(self.dataloader)
         desc = self.get_desc()
-        bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
+        bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
         self.init_metrics(de_parallel(model))
         with torch.no_grad():
             for batch_i, batch in enumerate(bar):
@@ -14,6 +14,7 @@ NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiproces
 AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
 FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf
 VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
+TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}'  # tqdm bar format
 LOGGING_NAME = 'yolov5'

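A minimal sketch of how the new TQDM_BAR_FORMAT constant is consumed (the range loop is a stand-in for real per-image work; all template fields used are standard tqdm ones):

import time

from tqdm import tqdm

TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}'  # 10-char bar, count and elapsed time; drops the default rate/ETA tail

for _ in tqdm(range(128), desc='Scanning', bar_format=TQDM_BAR_FORMAT):
    time.sleep(0.01)  # stand-in for per-image verification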
@@ -1,9 +1,10 @@
 import glob
+import inspect
 import platform
-import sys
 import urllib
 from pathlib import Path
 from subprocess import check_output
+from typing import Optional

 import pkg_resources as pkg
 import torch
@@ -128,3 +129,27 @@ def check_file(file, suffix=''):
 def check_yaml(file, suffix=('.yaml', '.yml')):
     # Search/download YAML file (if necessary) and return path, checking suffix
     return check_file(file, suffix)
+
+
+def git_describe(path=ROOT):  # path must be a directory
+    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+    try:
+        assert (Path(path) / '.git').is_dir()
+        return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
+    except Exception:
+        return ''
+
+
+def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
+    # Print function arguments (optional args dict)
+    x = inspect.currentframe().f_back  # previous frame
+    file, _, func, _, _ = inspect.getframeinfo(x)
+    if args is None:  # get args automatically
+        args, _, _, frm = inspect.getargvalues(x)
+        args = {k: v for k, v in frm.items() if k in args}
+    try:
+        file = Path(file).resolve().relative_to(ROOT).with_suffix('')
+    except ValueError:
+        file = Path(file).stem
+    s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
+    LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
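A hedged usage sketch for the two helpers added above; the log lines in the comments are illustrative, not captured output:

print_args({'epochs': 100, 'batch_size': 16})
# logs e.g. "engine/trainer: epochs=100, batch_size=16" (prefix is the calling file, relative to ROOT, suffix stripped)

def fit(lr0=0.01, momentum=0.937):
    print_args()  # no dict given: inspects the calling frame and logs fit's own arguments

print(git_describe())  # e.g. 'v5.0-5-g3e25f1e' inside a git checkout, '' otherwise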
@@ -1,5 +1,6 @@
 import contextlib
 import os
+from datetime import datetime
 from pathlib import Path
 from zipfile import ZipFile

@@ -61,3 +62,15 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
     for f in zipObj.namelist():  # list all archived filenames in the zip
         if all(x not in f for x in exclude):
             zipObj.extract(f, path=path)
+
+
+def file_age(path=__file__):
+    # Return days since last file update
+    dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime))  # delta
+    return dt.days  # + dt.seconds / 86400  # fractional days
+
+
+def file_date(path=__file__):
+    # Return human-readable file modification date, i.e. '2021-3-26'
+    t = datetime.fromtimestamp(Path(path).stat().st_mtime)
+    return f'{t.year}-{t.month}-{t.day}'
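Hedged usage of the two new file helpers (returned values are illustrative, and 'train.py' is a hypothetical path):

print(file_age())            # e.g. 3, whole days since this module was last modified
print(file_age('train.py'))  # any file path works
print(file_date())           # e.g. '2021-3-26' (no zero padding on month/day)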
@@ -1,5 +1,6 @@
 import math
 import os
+import platform
 import time
 from contextlib import contextmanager
 from copy import deepcopy
@@ -12,7 +13,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP

+import ultralytics
 from ultralytics.yolo.utils import LOGGER
+from ultralytics.yolo.utils.checks import git_describe

 from .checks import check_version

@@ -44,8 +47,8 @@ def DDP_model(model):

 def select_device(device='', batch_size=0, newline=True):
     # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
-    # s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
-    s = f'YOLOv5 🚀 torch-{torch.__version__} '
+    ver = git_describe() or ultralytics.__version__  # git commit or pip package version
+    s = f'Ultralytics YOLO 🚀 {ver} Python-{platform.python_version()} torch-{torch.__version__} '
     device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
     cpu = device == 'cpu'
     mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
@@ -75,7 +78,7 @@ def select_device(device='', batch_size=0, newline=True):

     if not newline:
         s = s.rstrip()
-    print(s)
+    LOGGER.info(s)
     return torch.device(arg)

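A hedged usage sketch for select_device(); the banner comment shows the shape of the new LOGGER output, with illustrative version and device values:

device = select_device('cpu')                 # force CPU
device = select_device('0', batch_size=16)    # single GPU
device = select_device('0,1', batch_size=32)  # multi-GPU; batch_size must divide evenly across GPUs

# First logged line now reads along the lines of:
# Ultralytics YOLO 🚀 v8.0-23-gabc1234 Python-3.10.8 torch-1.13.1 CUDA:0 (Tesla T4, 15110MiB)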
@@ -1,7 +1,3 @@
-import subprocess
-import time
-from pathlib import Path
-
 import hydra
 import torch
 import torch.nn as nn
@@ -10,7 +6,6 @@ import torch.nn.functional as F
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_dataloader
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
-from ultralytics.yolo.utils.anchors import check_anchors
 from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
 from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
 from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
@@ -24,7 +19,7 @@ class SegmentationTrainer(BaseTrainer):
         # TODO: manage splits differently
         # calculate stride - check if model is initialized
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
-        loader = build_dataloader(
+        return build_dataloader(
             img_path=dataset_path,
             img_size=self.args.img_size,
             batch_size=batch_size,
@@ -38,18 +33,16 @@ class SegmentationTrainer(BaseTrainer):
             shuffle=self.args.shuffle,
             use_segments=True,
         )[0]
-        return loader

     def preprocess_batch(self, batch):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
         return batch

     def load_model(self, model_cfg, weights, data):
-        model = SegmentationModel(model_cfg if model_cfg else weights["model"].yaml,
+        model = SegmentationModel(model_cfg or weights["model"].yaml,
                                   ch=3,
                                   nc=data["nc"],
                                   anchors=self.args.get("anchors"))
-        check_anchors(model, self.args.anchor_t, self.args.img_size)
         if weights:
             model.load(weights)
         return model
@@ -1,48 +0,0 @@
-# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
-
-# Parameters
-nc: 80  # number of classes
-depth_multiple: 0.33  # model depth multiple
-width_multiple: 0.25  # layer channel multiple
-anchors:
-  - [10,13, 16,30, 33,23]  # P3/8
-  - [30,61, 62,45, 59,119]  # P4/16
-  - [116,90, 156,198, 373,326]  # P5/32
-
-# YOLOv5 v6.0 backbone
-backbone:
-  # [from, number, module, args]
-  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
-   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
-   [-1, 3, C3, [128]],
-   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
-   [-1, 6, C3, [256]],
-   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
-   [-1, 9, C3, [512]],
-   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
-   [-1, 3, C3, [1024]],
-   [-1, 1, SPPF, [1024, 5]],  # 9
-  ]
-
-# YOLOv5 v6.0 head
-head:
-  [[-1, 1, Conv, [512, 1, 1]],
-   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
-   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
-   [-1, 3, C3, [512, False]],  # 13
-
-   [-1, 1, Conv, [256, 1, 1]],
-   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
-   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
-   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
-
-   [-1, 1, Conv, [256, 3, 2]],
-   [[-1, 14], 1, Concat, [1]],  # cat head P4
-   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
-
-   [-1, 1, Conv, [512, 3, 2]],
-   [[-1, 10], 1, Concat, [1]],  # cat head P5
-   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
-
-   [[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]],  # Detect(P3, P4, P5)
-  ]