mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
General console printout updates (#48)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8530e3fae0
commit
27d6545117
@ -665,7 +665,7 @@ def mosaic_transforms(img_size, hyp):
|
||||
perspective=hyp.perspective,
|
||||
border=[-img_size // 2, -img_size // 2],
|
||||
),])
|
||||
transforms = Compose([
|
||||
return Compose([
|
||||
pre_transform,
|
||||
MixUp(
|
||||
pre_transform=pre_transform,
|
||||
@ -674,13 +674,11 @@ def mosaic_transforms(img_size, hyp):
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),])
|
||||
return transforms
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
|
||||
|
||||
|
||||
def affine_transforms(img_size, hyp):
|
||||
# rect, randomperspective, albumentation, hsv, flipud, fliplr
|
||||
transforms = Compose([
|
||||
return Compose([
|
||||
LetterBox(new_shape=(img_size, img_size)),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
@ -693,11 +691,10 @@ def affine_transforms(img_size, hyp):
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),])
|
||||
return transforms
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
|
||||
|
||||
|
||||
# Classification augmentations -------------------------------------------------------------------------------------------
|
||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||
def classify_transforms(size=224):
|
||||
# Transforms to apply if albumentations not installed
|
||||
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
|
||||
|
@ -9,8 +9,8 @@ import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import NUM_THREADS
|
||||
from .utils import BAR_FORMAT, HELP_URL, IMG_FORMATS, LOCAL_RANK
|
||||
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
|
||||
from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
@ -18,7 +18,7 @@ class BaseDataset(Dataset):
|
||||
Args:
|
||||
img_path (str): image path.
|
||||
pipeline (dict): a dict of image transforms.
|
||||
label_path (str): label path, this can also be a ann_file or other custom label path.
|
||||
label_path (str): label path, this can also be an ann_file or other custom label path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -131,7 +131,7 @@ class BaseDataset(Dataset):
|
||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
||||
fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image
|
||||
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if self.cache == "disk":
|
||||
gb += self.npy_files[i].stat().st_size
|
||||
|
@ -6,10 +6,10 @@ from typing import OrderedDict
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import NUM_THREADS
|
||||
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
|
||||
from .augment import *
|
||||
from .base import BaseDataset
|
||||
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||
from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||
|
||||
|
||||
class YOLODataset(BaseDataset):
|
||||
@ -40,7 +40,7 @@ class YOLODataset(BaseDataset):
|
||||
):
|
||||
self.use_segments = use_segments
|
||||
self.use_keypoints = use_keypoints
|
||||
assert not (self.use_segments and self.use_keypoints), "We can't use both of segmentation and pose."
|
||||
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
|
||||
super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
|
||||
single_cls)
|
||||
|
||||
@ -48,14 +48,14 @@ class YOLODataset(BaseDataset):
|
||||
# Cache dataset labels, check images and read shapes
|
||||
x = {"labels": []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
|
||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
with Pool(NUM_THREADS) as pool:
|
||||
pbar = tqdm(
|
||||
pool.imap(verify_image_label,
|
||||
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
|
||||
desc=desc,
|
||||
total=len(self.im_files),
|
||||
bar_format=BAR_FORMAT,
|
||||
bar_format=TQDM_BAR_FORMAT,
|
||||
)
|
||||
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||
nm += nm_f
|
||||
@ -76,7 +76,7 @@ class YOLODataset(BaseDataset):
|
||||
))
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
|
||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
|
||||
pbar.close()
|
||||
if msgs:
|
||||
@ -109,8 +109,8 @@ class YOLODataset(BaseDataset):
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in {-1, 0}:
|
||||
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
|
||||
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
|
||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
|
||||
|
@ -22,7 +22,6 @@ from ..utils.ops import segments2boxes
|
||||
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
|
||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
|
||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
|
||||
BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}" # tqdm bar format
|
||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||
|
@ -25,8 +25,8 @@ from tqdm import tqdm
|
||||
import ultralytics.yolo.utils as utils
|
||||
import ultralytics.yolo.utils.loggers as loggers
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT
|
||||
from ultralytics.yolo.utils.checks import check_file, check_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils.checks import print_args
|
||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
|
||||
@ -41,19 +41,17 @@ class BaseTrainer:
|
||||
self.validator = None
|
||||
self.model = None
|
||||
self.callbacks = defaultdict(list)
|
||||
self.console.info(f"Training config: \n args: \n {self.args}") # to debug
|
||||
# Directories
|
||||
self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
||||
self.wdir = self.save_dir / 'weights'
|
||||
self.wdir = self.save_dir / 'weights' # weights dir
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||
print_args(dict(self.args))
|
||||
|
||||
# Save run settings
|
||||
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
|
||||
|
||||
# device
|
||||
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
|
||||
self.console.info(f"running on device {self.device}")
|
||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||
|
||||
# Model and Dataloaders.
|
||||
@ -64,7 +62,7 @@ class BaseTrainer:
|
||||
self.data = check_dataset(self.data)
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
if self.args.model:
|
||||
self.model = self.get_model(self.args.model, self.data)
|
||||
self.model = self.get_model(self.args.model)
|
||||
|
||||
# epoch level metrics
|
||||
self.metrics = {} # handle metrics returned by validator
|
||||
@ -115,7 +113,7 @@ class BaseTrainer:
|
||||
if world_size > 1:
|
||||
mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
|
||||
else:
|
||||
self._do_train(-1, 1)
|
||||
self._do_train()
|
||||
|
||||
def _setup_ddp(self, rank, world_size):
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
@ -147,7 +145,7 @@ class BaseTrainer:
|
||||
print("created testloader :", rank)
|
||||
self.console.info(self.progress_string())
|
||||
|
||||
def _do_train(self, rank, world_size):
|
||||
def _do_train(self, rank=-1, world_size=1):
|
||||
if world_size > 1:
|
||||
self._setup_ddp(rank, world_size)
|
||||
else:
|
||||
@ -165,9 +163,7 @@ class BaseTrainer:
|
||||
self.model.train()
|
||||
pbar = enumerate(self.train_loader)
|
||||
if rank in {-1, 0}:
|
||||
pbar = tqdm(enumerate(self.train_loader),
|
||||
total=len(self.train_loader),
|
||||
bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
|
||||
tloss = None
|
||||
for i, batch in pbar:
|
||||
# img, label (classification)/ img, targets, paths, _, masks(detection)
|
||||
@ -249,18 +245,14 @@ class BaseTrainer:
|
||||
"""
|
||||
return data["train"], data["val"]
|
||||
|
||||
def get_model(self, model: str, data: Dict):
|
||||
def get_model(self, model: Union[str, Path]):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
pretrained = False
|
||||
if not str(model).endswith(".yaml"):
|
||||
pretrained = True
|
||||
weights = get_model(model) # rename this to something less confusing?
|
||||
model = self.load_model(model_cfg=model if not pretrained else None,
|
||||
weights=weights if pretrained else None,
|
||||
data=self.data)
|
||||
return model
|
||||
pretrained = not str(model).endswith(".yaml")
|
||||
return self.load_model(model_cfg=None if pretrained else model,
|
||||
weights=get_model(model) if pretrained else None,
|
||||
data=self.data) # model
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
@ -5,6 +5,7 @@ from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
|
||||
|
||||
@ -49,7 +50,7 @@ class BaseValidator:
|
||||
loss = 0
|
||||
n_batches = len(self.dataloader)
|
||||
desc = self.get_desc()
|
||||
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
|
||||
self.init_metrics(de_parallel(model))
|
||||
with torch.no_grad():
|
||||
for batch_i, batch in enumerate(bar):
|
||||
|
@ -14,6 +14,7 @@ NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiproces
|
||||
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
||||
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
|
||||
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
|
||||
TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
|
||||
LOGGING_NAME = 'yolov5'
|
||||
|
||||
|
||||
|
@ -1,9 +1,10 @@
|
||||
import glob
|
||||
import inspect
|
||||
import platform
|
||||
import sys
|
||||
import urllib
|
||||
from pathlib import Path
|
||||
from subprocess import check_output
|
||||
from typing import Optional
|
||||
|
||||
import pkg_resources as pkg
|
||||
import torch
|
||||
@ -128,3 +129,27 @@ def check_file(file, suffix=''):
|
||||
def check_yaml(file, suffix=('.yaml', '.yml')):
|
||||
# Search/download YAML file (if necessary) and return path, checking suffix
|
||||
return check_file(file, suffix)
|
||||
|
||||
|
||||
def git_describe(path=ROOT): # path must be a directory
|
||||
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
|
||||
try:
|
||||
assert (Path(path) / '.git').is_dir()
|
||||
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
|
||||
except Exception:
|
||||
return ''
|
||||
|
||||
|
||||
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
||||
# Print function arguments (optional args dict)
|
||||
x = inspect.currentframe().f_back # previous frame
|
||||
file, _, func, _, _ = inspect.getframeinfo(x)
|
||||
if args is None: # get args automatically
|
||||
args, _, _, frm = inspect.getargvalues(x)
|
||||
args = {k: v for k, v in frm.items() if k in args}
|
||||
try:
|
||||
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
|
||||
except ValueError:
|
||||
file = Path(file).stem
|
||||
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
|
||||
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
||||
|
@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
@ -61,3 +62,15 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
|
||||
for f in zipObj.namelist(): # list all archived filenames in the zip
|
||||
if all(x not in f for x in exclude):
|
||||
zipObj.extract(f, path=path)
|
||||
|
||||
|
||||
def file_age(path=__file__):
|
||||
# Return days since last file update
|
||||
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
|
||||
return dt.days # + dt.seconds / 86400 # fractional days
|
||||
|
||||
|
||||
def file_date(path=__file__):
|
||||
# Return human-readable file modification date, i.e. '2021-3-26'
|
||||
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||
return f'{t.year}-{t.month}-{t.day}'
|
||||
|
@ -1,5 +1,6 @@
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
@ -12,7 +13,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import ultralytics
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils.checks import git_describe
|
||||
|
||||
from .checks import check_version
|
||||
|
||||
@ -44,8 +47,8 @@ def DDP_model(model):
|
||||
|
||||
def select_device(device='', batch_size=0, newline=True):
|
||||
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
||||
# s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
|
||||
s = f'YOLOv5 🚀 torch-{torch.__version__} '
|
||||
ver = git_describe() or ultralytics.__version__ # git commit or pip package version
|
||||
s = f'Ultralytics YOLO 🚀 {ver} Python-{platform.python_version()} torch-{torch.__version__} '
|
||||
device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
|
||||
cpu = device == 'cpu'
|
||||
mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
|
||||
@ -75,7 +78,7 @@ def select_device(device='', batch_size=0, newline=True):
|
||||
|
||||
if not newline:
|
||||
s = s.rstrip()
|
||||
print(s)
|
||||
LOGGER.info(s)
|
||||
return torch.device(arg)
|
||||
|
||||
|
||||
|
@ -1,7 +1,3 @@
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -10,7 +6,6 @@ import torch.nn.functional as F
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_dataloader
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||
from ultralytics.yolo.utils.anchors import check_anchors
|
||||
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
||||
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
|
||||
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
|
||||
@ -24,7 +19,7 @@ class SegmentationTrainer(BaseTrainer):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
loader = build_dataloader(
|
||||
return build_dataloader(
|
||||
img_path=dataset_path,
|
||||
img_size=self.args.img_size,
|
||||
batch_size=batch_size,
|
||||
@ -38,18 +33,16 @@ class SegmentationTrainer(BaseTrainer):
|
||||
shuffle=self.args.shuffle,
|
||||
use_segments=True,
|
||||
)[0]
|
||||
return loader
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||
return batch
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
model = SegmentationModel(model_cfg if model_cfg else weights["model"].yaml,
|
||||
model = SegmentationModel(model_cfg or weights["model"].yaml,
|
||||
ch=3,
|
||||
nc=data["nc"],
|
||||
anchors=self.args.get("anchors"))
|
||||
check_anchors(model, self.args.anchor_t, self.args.img_size)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
return model
|
||||
|
@ -1,48 +0,0 @@
|
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||
|
||||
# Parameters
|
||||
nc: 80 # number of classes
|
||||
depth_multiple: 0.33 # model depth multiple
|
||||
width_multiple: 0.25 # layer channel multiple
|
||||
anchors:
|
||||
- [10,13, 16,30, 33,23] # P3/8
|
||||
- [30,61, 62,45, 59,119] # P4/16
|
||||
- [116,90, 156,198, 373,326] # P5/32
|
||||
|
||||
# YOLOv5 v6.0 backbone
|
||||
backbone:
|
||||
# [from, number, module, args]
|
||||
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
|
||||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
|
||||
[-1, 3, C3, [128]],
|
||||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
|
||||
[-1, 6, C3, [256]],
|
||||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
|
||||
[-1, 9, C3, [512]],
|
||||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
|
||||
[-1, 3, C3, [1024]],
|
||||
[-1, 1, SPPF, [1024, 5]], # 9
|
||||
]
|
||||
|
||||
# YOLOv5 v6.0 head
|
||||
head:
|
||||
[[-1, 1, Conv, [512, 1, 1]],
|
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||
[[-1, 6], 1, Concat, [1]], # cat backbone P4
|
||||
[-1, 3, C3, [512, False]], # 13
|
||||
|
||||
[-1, 1, Conv, [256, 1, 1]],
|
||||
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
|
||||
[[-1, 4], 1, Concat, [1]], # cat backbone P3
|
||||
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
|
||||
|
||||
[-1, 1, Conv, [256, 3, 2]],
|
||||
[[-1, 14], 1, Concat, [1]], # cat head P4
|
||||
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)
|
||||
|
||||
[-1, 1, Conv, [512, 3, 2]],
|
||||
[[-1, 10], 1, Concat, [1]], # cat head P5
|
||||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)
|
||||
|
||||
[[17, 20, 23], 1, Segment, [nc, anchors, 32, 256]], # Detect(P3, P4, P5)
|
||||
]
|
Loading…
x
Reference in New Issue
Block a user