diff --git a/.gitignore b/.gitignore index 0854267a..1463e020 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,7 @@ coverage.xml .hypothesis/ .pytest_cache/ mlruns/ +figures/ # Translations *.mo diff --git a/run_attribution.py b/run_attribution.py new file mode 100644 index 00000000..8c87e592 --- /dev/null +++ b/run_attribution.py @@ -0,0 +1,34 @@ +from ultralytics import YOLOv10, YOLO +# from ultralytics.engine.pgt_trainer import PGTTrainer +# from ultralytics import BaseTrainer +# from ultralytics.engine.trainer import BaseTrainer +import os + +# Set CUDA device (only needed for multi-gpu machines) +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "4" + +# model = YOLOv10() +# model = YOLO() +# If you want to finetune the model with pretrained weights, you could load the +# pretrained weights like below +# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}') +# or +# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt +model = YOLOv10('yolov10n.pt') + +model.train(data='coco.yaml', + trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT) + # Add return_images as input parameter + epochs=500, batch=16, imgsz=640, + debug=True, # If debug = True, the attributions will be saved in the figures folder + ) + +# Save the trained model +model.save('yolov10_coco_trained.pt') + +# Evaluate the model on the validation set +results = model.val(data='coco.yaml') + +# Print the evaluation results +print(results) \ No newline at end of file diff --git a/run_train.py b/run_train.py new file mode 100644 index 00000000..012d9d83 --- /dev/null +++ b/run_train.py @@ -0,0 +1,83 @@ +from ultralytics import YOLOv10, YOLO +# from ultralytics.engine.pgt_trainer import PGTTrainer +# from ultralytics import BaseTrainer +# from ultralytics.engine.trainer import BaseTrainer +import os + +# Set CUDA device (only needed for multi-gpu 
machines) +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "4" + +model = YOLOv10() +# model = YOLO() +# If you want to finetune the model with pretrained weights, you could load the +# pretrained weights like below +# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}') +# or +# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt +# model = YOLOv10('yolov10m.pt') + +model.train(data='coco.yaml', + # Add return_images as input parameter + epochs=500, batch=16, imgsz=640, + ) + +# Save the trained model +model.save('yolov10_coco_trained.pt') + +# Evaluate the model on the validation set +results = model.val(data='coco.yaml') + +# Print the evaluation results +print(results) + +# import torch +# from torch.utils.data import DataLoader +# from torchvision import datasets, transforms + +# # Define the transformation for the dataset +# transform = transforms.Compose([ +# transforms.Resize((640, 640)), +# transforms.ToTensor() +# ]) + +# # Load the COCO dataset +# train_dataset = datasets.CocoDetection(root='data/nielseni6/coco/train2017', annFile='/data/nielseni6/coco/annotations/instances_train2017.json', transform=transform) +# val_dataset = datasets.CocoDetection(root='data/nielseni6/coco/val2017', annFile='/data/nielseni6/coco/annotations/instances_val2017.json', transform=transform) + +# # Create data loaders +# train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4) +# val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4) + +# model = YOLOv10() + +# # Define the optimizer +# optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + +# # Training loop +# for epoch in range(500): +# model.train() +# for images, targets in train_loader: +# images = images.to('cuda') +# targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets] +# loss = model(images, targets) +# loss.backward() +# optimizer.step() +# 
optimizer.zero_grad() + +# # Validation loop +# model.eval() +# with torch.no_grad(): +# for images, targets in val_loader: +# images = images.to('cuda') +# targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets] +# results = model(images, targets) + +# # Save the trained model +# model.save('yolov10_coco_trained.pt') + +# # Evaluate the model on the validation set +# results = model.val(data='coco.yaml') + +# # Print the evaluation results +# print(results) \ No newline at end of file diff --git a/run_val.py b/run_val.py new file mode 100644 index 00000000..6003d5ee --- /dev/null +++ b/run_val.py @@ -0,0 +1,96 @@ +from ultralytics import YOLOv10 +import torch +from PIL import Image +from torchvision import transforms + +# Define the device +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}') +# model = YOLOv10.from_pretrained('jameslahm/yolov10n') +# or +# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt +# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10n.pt +# model = YOLOv10('yolov10{n/s/m/b/l/x}.pt') +model = YOLOv10('yolov10n.pt').to(device) + +# Load the image +# path = '/home/nielseni6/PythonScripts/Github/yolov10/images/fat-dog.jpg' +path = '/home/nielseni6/PythonScripts/Github/yolov10/images/The-Cardinal-Bird.jpg' +image = Image.open(path) + +# Define the transformation to resize the image, convert it to a tensor, and normalize it +transform = transforms.Compose([ + transforms.Resize((640, 640)), + transforms.ToTensor(), + # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) +]) + +# Apply the transformation +image_tensor = transform(image) + +# Add a batch dimension +image_tensor = image_tensor.unsqueeze(0).to(device) +image_tensor = image_tensor.requires_grad_(True) + + +# Predict for a specific image +# results = model.predict(image_tensor, save=True) +# 
model.requires_grad_(True) + + +# for p in model.parameters(): +# p.requires_grad = True +results = model.predict(image_tensor, save=True) + +# Display the results +for result in results: + print(result) + +# pred = results[0].boxes[0].conf + +# # Hook to store the activations +# activations = {} + +# def get_activation(name): +# def hook(model, input, output): +# activations[name] = output +# return hook + +# # Register hooks for each layer you want to inspect +# for name, layer in model.model.named_modules(): +# layer.register_forward_hook(get_activation(name)) + +# # Run the model to get activations +# results = model.predict(image_tensor, save=True, visualize=True) + +# # # Print the activations +# # for name, activation in activations.items(): +# # print(f"Activation from layer {name}: {activation}") + +# # List activation names separately +# print("\nActivation layer names:") +# for name in activations.keys(): +# print(name) +# # pred.backward() + +# # Assuming 'model.23' is the layer of interest for bbox prediction and confidence +# activation = activations['model.23']['one2one'][0] +# act_23 = activations['model.23.cv3.2'] +# act_dfl = activations['model.23.dfl.conv'] +# act_conv = activations['model.0.conv'] +# act_act = activations['model.0.act'] + +# # with torch.autograd.set_detect_anomaly(True): +# # pred.backward() +# grad = torch.autograd.grad(act_23, im, grad_outputs=torch.ones_like(act_23), create_graph=True, retain_graph=True)[0] +# # grad = torch.autograd.grad(pred, im, grad_outputs=torch.ones_like(pred), create_graph=True)[0] +# grad = torch.autograd.grad(activations['model.23']['one2one'][1][0], +# activations['model.23.one2one_cv3.2'], +# grad_outputs=torch.ones_like(activations['model.23']['one2one'][1][0]), +# create_graph=True, retain_graph=True)[0] + +# # Print the results +# print(results) + +# model.val(data='coco.yaml', batch=256) \ No newline at end of file diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 
def train(
    self,
    trainer=None,
    debug=False,
    **kwargs,
):
    """
    Trains the model using the specified dataset and training configuration.

    This method facilitates model training with a range of customizable settings and configurations. It supports
    training with a custom trainer or the default training approach defined in the method. The method handles
    different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
    updating model and configuration after training.

    When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
    arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
    configurations, method-specific defaults, and user-provided arguments to configure the training process. After
    training, it updates the model and its configurations, and optionally attaches metrics.

    Args:
        trainer (BaseTrainer, optional): A custom trainer class used to train the model. It is called as
            ``trainer(overrides=..., _callbacks=...)``. If None, the class returned by
            ``self._smart_load("trainer")`` is used. Defaults to None.
        debug (bool, optional): When True, ``debug=True`` is forwarded to the trainer's ``train`` method, so the
            selected trainer must accept that keyword (e.g. the PGT trainer). When False the trainer is invoked
            without the keyword, which keeps trainers that do not know about ``debug`` working. Defaults to False.
        **kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are
            used to customize various aspects of the training process.

    Returns:
        (dict | None): Training metrics if available and training is successful; otherwise, None.

    Raises:
        AssertionError: If the model is not a PyTorch model.
        PermissionError: If there is a permission issue with the HUB session.
        ModuleNotFoundError: If the HUB SDK is not installed.
    """
    self._check_is_pytorch_model()
    if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model
        if any(kwargs):
            LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
        kwargs = self.session.train_args  # overwrite kwargs

    checks.check_pip_update_available()

    overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
    custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]}  # method defaults
    args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
    if args.get("resume"):
        args["resume"] = self.ckpt_path

    self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
    if not args.get("resume"):  # manually set model only if not resuming
        self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
        self.model = self.trainer.model

        if SETTINGS["hub"] is True and not self.session:
            # Create a model in HUB
            try:
                self.session = self._get_hub_session(self.model_name)
                if self.session:
                    self.session.create_model(args)
                    # Check model was created
                    if not getattr(self.session.model, "id", None):
                        self.session = None
            except (PermissionError, ModuleNotFoundError):
                # Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
                pass

    self.trainer.hub_session = self.session  # attach optional HUB session
    # Only pass `debug` through when it was requested: unconditionally forwarding the keyword would require
    # every trainer class (including the default one) to accept it, and raise TypeError otherwise.
    if debug:
        self.trainer.train(debug=True)
    else:
        self.trainer.train()
    # Update model and cfg after training
    if RANK in (-1, 0):
        ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
        self.model, _ = attempt_load_one_weight(ckpt)
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
    return self.metrics
+ +Usage: + $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16 +""" + +import math +import os +import subprocess +import time +import warnings +from copy import deepcopy +from datetime import datetime, timedelta +from pathlib import Path + +import numpy as np +import torch +from torch import distributed as dist +from torch import nn, optim + +import matplotlib.pyplot as plt +import torchvision.transforms as T + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data.utils import check_cls_dataset, check_det_dataset +from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights +from ultralytics.utils import ( + DEFAULT_CFG, + LOGGER, + RANK, + TQDM, + __version__, + callbacks, + clean_url, + colorstr, + emojis, + yaml_save, +) +from ultralytics.utils.autobatch import check_train_batch_size +from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args +from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command +from ultralytics.utils.files import get_latest_run +from ultralytics.utils.torch_utils import ( + EarlyStopping, + ModelEMA, + de_parallel, + init_seeds, + one_cycle, + select_device, + strip_optimizer, +) + + +class PGTTrainer: + """ + BaseTrainer. + + A base class for creating trainers. + + Attributes: + args (SimpleNamespace): Configuration for the trainer. + validator (BaseValidator): Validator instance. + model (nn.Module): Model instance. + callbacks (defaultdict): Dictionary of callbacks. + save_dir (Path): Directory to save results. + wdir (Path): Directory to save weights. + last (Path): Path to the last checkpoint. + best (Path): Path to the best checkpoint. + save_period (int): Save checkpoint every x epochs (disabled if < 1). + batch_size (int): Batch size for training. + epochs (int): Number of epochs to train for. + start_epoch (int): Starting epoch for training. + device (torch.device): Device to use for training. 
+ amp (bool): Flag to enable AMP (Automatic Mixed Precision). + scaler (amp.GradScaler): Gradient scaler for AMP. + data (str): Path to data. + trainset (torch.utils.data.Dataset): Training dataset. + testset (torch.utils.data.Dataset): Testing dataset. + ema (nn.Module): EMA (Exponential Moving Average) of the model. + resume (bool): Resume training from a checkpoint. + lf (nn.Module): Loss function. + scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. + best_fitness (float): The best fitness value achieved. + fitness (float): Current fitness value. + loss (float): Current loss value. + tloss (float): Total loss value. + loss_names (list): List of loss names. + csv (Path): Path to results CSV file. + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initializes the BaseTrainer class. + + Args: + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. + overrides (dict, optional): Configuration overrides. Defaults to None. 
+ """ + self.args = get_cfg(cfg, overrides) + self.check_resume(overrides) + self.device = select_device(self.args.device, self.args.batch) + self.validator = None + self.metrics = None + self.plots = {} + init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) + + # Dirs + self.save_dir = get_save_dir(self.args) + self.args.name = self.save_dir.name # update name for loggers + self.wdir = self.save_dir / "weights" # weights dir + if RANK in (-1, 0): + self.wdir.mkdir(parents=True, exist_ok=True) # make dir + self.args.save_dir = str(self.save_dir) + yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args + self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths + self.save_period = self.args.save_period + + self.batch_size = self.args.batch + self.epochs = self.args.epochs + self.start_epoch = 0 + if RANK == -1: + print_args(vars(self.args)) + + # Device + if self.device.type in ("cpu", "mps"): + self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading + + # Model and Dataset + self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. 
def train(self, debug=False):
    """
    Work out the world size from ``self.args.device`` and start training.

    With more than one device (and no ``LOCAL_RANK`` in the environment, i.e. we are not already inside a DDP
    worker) a DDP subprocess is launched; otherwise training runs in this process via ``_do_train``.

    Args:
        debug (bool): Forwarded to ``_do_train`` for in-process training. It is NOT passed to the DDP subprocess
            launched here, so a warning is logged when it is set together with multi-GPU training.
    """
    device = self.args.device
    if isinstance(device, str) and len(device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(device.split(","))
    elif isinstance(device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(device)
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # i.e. device='cpu' or 'mps'
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        if debug:
            # Neither generate_ddp_command() nor the spawned command receives `debug` from this method
            LOGGER.warning("WARNING ⚠️ 'debug=True' is not forwarded to Multi-GPU (DDP) training and is ignored")

        # Argument checks
        if self.args.rect:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch == -1:
            LOGGER.warning(
                "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
                "default 'batch=16'"
            )
            self.args.batch = 16

        # Command
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
            subprocess.run(cmd, check=True)
        finally:
            # Errors propagate unchanged; the generated DDP file is cleaned up either way
            ddp_cleanup(self, str(file))

    else:
        self._do_train(world_size, debug=debug)
+ self.model = self.model.to(self.device) + self.set_model_attributes() + + # Freeze layers + freeze_list = ( + self.args.freeze + if isinstance(self.args.freeze, list) + else range(self.args.freeze) + if isinstance(self.args.freeze, int) + else [] + ) + always_freeze_names = [".dfl"] # always freeze these layers + freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names + for k, v in self.model.named_parameters(): + # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results) + if any(x in k for x in freeze_layer_names): + LOGGER.info(f"Freezing layer '{k}'") + v.requires_grad = False + elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients + LOGGER.info( + f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " + "See ultralytics.engine.trainer for customization of frozen layers." + ) + v.requires_grad = True + + # Check AMP + self.amp = torch.tensor(self.args.amp).to(self.device) # True or False + if self.amp and RANK in (-1, 0): # Single-GPU and DDP + callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them + self.amp = torch.tensor(check_amp(self.model), device=self.device) + callbacks.default_callbacks = callbacks_backup # restore callbacks + if RANK > -1 and world_size > 1: # DDP + dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None) + self.amp = bool(self.amp) # as boolean + self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp) + if world_size > 1: + self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK]) + + # Check imgsz + gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride) + self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1) + self.stride = gs # for multiscale training + + # Batch size + if self.batch_size == -1 and RANK == -1: # 
def _save_debug_attributions(self, images, epoch, batch_idx, save_dir_attr="figures/attributions"):
    """
    Save image / input-gradient / overlay figures for every sample of the current batch.

    The gradient of ``self.loss`` with respect to ``images`` is used as the attribution map. ``images`` must be
    part of the autograd graph of ``self.loss`` and is indexed as (batch, channel, height, width) by the
    transpose below.

    Args:
        images (torch.Tensor): Input images returned by the model's forward pass.
        epoch (int): Current epoch, used in the file name.
        batch_idx (int): Index of the batch within the epoch, used in the file name.
        save_dir_attr (str): Output directory for the figures; created if missing.
    """
    # retain_graph keeps the graph alive for the real backward pass that follows; a second-order graph
    # (create_graph=True) is not needed because the gradient is only visualised.
    grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]

    # Convert tensors to channel-last numpy arrays for matplotlib
    images_np = images.detach().float().cpu().numpy().transpose(0, 2, 3, 1)
    grad_np = grad.detach().float().cpu().numpy().transpose(0, 2, 3, 1)

    # Normalize grad to [0, 1] for visualization; a constant gradient would otherwise divide by zero
    grad_min = grad_np.min()
    grad_range = grad_np.max() - grad_min
    grad_np = (grad_np - grad_min) / grad_range if grad_range > 0 else np.zeros_like(grad_np)

    os.makedirs(save_dir_attr, exist_ok=True)
    for ix, (image, attribution) in enumerate(zip(images_np, grad_np)):
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        ax[0].imshow(image)
        ax[0].set_title('Image')
        ax[1].imshow(attribution, cmap='jet')
        ax[1].set_title('Gradient')
        ax[2].imshow(image)
        ax[2].imshow(attribution, cmap='jet', alpha=0.5)
        ax[2].set_title('Overlay')
        fig.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{batch_idx}_image_{ix}.png')
        plt.close(fig)


def _do_train(self, world_size=1, debug=False):
    """
    Run the full training loop: setup, per-epoch training, validation, checkpointing and final evaluation.

    Args:
        world_size (int): Number of devices; values above 1 initialise DDP before setup.
        debug (bool): When True, input-gradient attribution figures are written to ``figures/attributions``
            for every 250th batch of each epoch (starting with batch 0).
    """
    if world_size > 1:
        self._setup_ddp(world_size)
    self._setup_train(world_size)

    nb = len(self.train_loader)  # number of batches
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
    last_opt_step = -1
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
        f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.start_epoch
    while True:
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        self.model.train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(self.train_loader)
        # Update dataloader attributes (optional)
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in (-1, 0):
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)
        self.tloss = None
        self.optimizer.zero_grad()
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")
            # Warmup
            ni = i + nb * epoch
            if ni <= nw:
                xi = [0, nw]  # x interp
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # Forward
            with torch.cuda.amp.autocast(self.amp):
                batch = self.preprocess_batch(batch)
                (self.loss, self.loss_items), images = self.model(batch, return_images=True)

            # Attribution figures only on every 250th batch; must run before backward() frees the graph
            if debug and i % 250 == 0:
                self._save_debug_attributions(images, epoch, i)

            if RANK != -1:
                self.loss *= world_size
            self.tloss = (
                (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
            )

            # Backward
            self.scaler.scale(self.loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= self.accumulate:
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # if DDP training
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
            loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
            losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
            if RANK in (-1, 0):
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                    % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
        self.run_callbacks("on_train_epoch_end")
        if RANK in (-1, 0):
            final_epoch = epoch + 1 == self.epochs
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

            # Validation: every val_period epochs, during the last 10 epochs, and whenever training may stop
            if (
                (self.args.val and (((epoch + 1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10))
                or final_epoch
                or self.stopper.possible_stop
                or self.stop
            ):
                self.metrics, self.fitness = self.validate()
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
            if self.args.time:
                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
                self._setup_scheduler()
                self.scheduler.last_epoch = self.epoch  # do not move
                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
            self.scheduler.step()
        self.run_callbacks("on_fit_epoch_end")
        torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

        # Early Stopping
        if RANK != -1:  # if DDP training
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break all DDP ranks
        epoch += 1

    if RANK in (-1, 0):
        # Do final val with best.pt
        LOGGER.info(
            f"\n{epoch - self.start_epoch + 1} epochs completed in "
            f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
        )
        self.final_eval()
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    torch.cuda.empty_cache()
    self.run_callbacks("teardown")
def validate(self):
    """
    Run `self.validator` on this trainer and keep track of the best fitness seen so far.

    The validator result is expected to be a dict containing a "fitness" key; when it is absent, the negated
    training loss is used as the fitness value instead.

    Returns:
        (tuple): `(metrics, fitness)` where `metrics` is the validator dict with "fitness" removed.
    """
    metrics = self.validator(self)
    # The fallback is computed unconditionally, so self.loss must be available even when "fitness" is present
    fallback_fitness = -self.loss.detach().cpu().numpy()
    fitness = metrics.pop("fitness", fallback_fitness)
    is_new_best = not self.best_fitness or self.best_fitness < fitness
    if is_new_best:
        self.best_fitness = fitness
    return metrics, fitness
def save_metrics(self, metrics):
    """
    Append one row of metrics to the CSV file at `self.csv`.

    A header row ("epoch" followed by the metric names) is written first when the file does not exist yet.
    Every cell is right-aligned to 23 characters; values are formatted with 5 significant digits and the
    epoch is stored 1-based (`self.epoch + 1`).

    Args:
        metrics (dict): Mapping of metric name to numeric value.
    """
    columns = ["epoch", *metrics.keys()]
    num_cols = len(columns)  # metric columns plus the leading epoch column
    if self.csv.exists():
        header = ""
    else:
        header = (("%23s," * num_cols) % tuple(columns)).rstrip(",") + "\n"
    with open(self.csv, "a") as f:
        cells = tuple([self.epoch + 1, *metrics.values()])
        row = (("%23.5g," * num_cols) % cells).rstrip(",") + "\n"
        f.write(header + row)
def final_eval(self):
    """
    Strip optimizer state from the saved checkpoints and validate the best one.

    Both `self.last` and `self.best` are processed when they exist on disk. Only the best checkpoint is
    validated afterwards: its metrics (without "fitness") replace `self.metrics` and the
    "on_fit_epoch_end" callbacks are run.
    """
    for ckpt_path in (self.last, self.best):
        if not ckpt_path.exists():
            continue
        strip_optimizer(ckpt_path)  # strip optimizers
        if ckpt_path is not self.best:
            continue
        LOGGER.info(f"\nValidating {ckpt_path}...")
        self.validator.args.plots = self.args.plots
        self.metrics = self.validator(model=ckpt_path)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
    weight decay, and number of iterations.

    Parameters are split into three groups: biases (no decay), normalization-layer weights (no decay) and all
    remaining weights (decayed with `decay`). The optimizer is created on the bias group and the two weight
    groups are added afterwards.

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations. Default: 'auto'.
        lr (float, optional): The learning rate for the optimizer. Default: 0.001.
        momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
        decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'. Default: 1e5.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.

    Raises:
        NotImplementedError: If `name` is not one of the supported optimizer names.
    """

    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = getattr(model, "nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    # recurse=False so every parameter is visited once, together with the module that directly owns it
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn):  # weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        # Message lists every name accepted above and keeps the two sentences separated
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers "
            "[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. "
            "To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
    )
    return optimizer
import DetectionPredictor +from .pgt_train import PGTDetectionTrainer from .train import DetectionTrainer from .val import DetectionValidator -__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator" +__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer" diff --git a/ultralytics/models/yolo/detect/pgt_train.py b/ultralytics/models/yolo/detect/pgt_train.py new file mode 100644 index 00000000..be960193 --- /dev/null +++ b/ultralytics/models/yolo/detect/pgt_train.py @@ -0,0 +1,144 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import math +import random +from copy import copy + +import numpy as np +import torch.nn as nn + +from ultralytics.data import build_dataloader, build_yolo_dataset +from ultralytics.engine.trainer import BaseTrainer +from ultralytics.engine.pgt_trainer import PGTTrainer +from ultralytics.models import yolo +from ultralytics.nn.tasks import DetectionModel +from ultralytics.utils import LOGGER, RANK +from ultralytics.utils.plotting import plot_images, plot_labels, plot_results +from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first + + +class PGTDetectionTrainer(PGTTrainer): + """ + A class extending the BaseTrainer class for training based on a detection model. + + Example: + ```python + from ultralytics.models.yolo.detect import DetectionTrainer + + args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3) + trainer = DetectionTrainer(overrides=args) + trainer.train() + ``` + """ + + def build_dataset(self, img_path, mode="train", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. 
def preprocess_batch(self, batch):
    """
    Preprocess a batch of images by moving them to the device, converting to float and scaling by 1/255.

    When `self.args.multi_scale` is set, the whole batch is additionally resized to a randomly drawn size
    between roughly 0.5x and 1.5x of `self.args.imgsz`, rounded to a multiple of `self.stride`.

    Args:
        batch (dict): Batch dict whose "img" entry is the image tensor; it is updated in place.

    Returns:
        (dict): The same `batch` dict with the processed "img" tensor.
    """
    batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
    if self.args.multi_scale:
        imgs = batch["img"]
        # random.randrange() needs integer bounds: float arguments are deprecated since Python 3.10
        # and raise TypeError from Python 3.12 on
        low = int(self.args.imgsz * 0.5)
        high = int(self.args.imgsz * 1.5 + self.stride)
        sz = random.randrange(low, high) // self.stride * self.stride  # size
        sf = sz / max(imgs.shape[2:])  # scale factor
        if sf != 1:
            ns = [
                math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
            ]  # new shape (stretched to gs-multiple)
            imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
        batch["img"] = imgs
    return batch
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label loss values with "<prefix>/<loss name>" keys built from `self.loss_names`.

    Args:
        loss_items (iterable, optional): Loss values in the same order as `self.loss_names`.
        prefix (str): Prefix put in front of every loss name. Default: "train".

    Returns:
        (dict | list): A dict mapping each key to its loss value as a float rounded to 5 decimal places, or
            just the list of keys when `loss_items` is None.
    """
    keys = [f"{prefix}/{name}" for name in self.loss_names]
    if loss_items is None:
        return keys
    rounded = [round(float(value), 5) for value in loss_items]  # convert tensors to 5 decimal place floats
    return dict(zip(keys, rounded))
class YOLOv10PGTDetectionTrainer(PGTDetectionTrainer):
    """PGTDetectionTrainer variant that builds the YOLOv10 detection model and the YOLOv10 validator."""

    def get_validator(self):
        """
        Set the six YOLOv10 loss names and return a YOLOv10DetectionValidator.

        The validator is created on `self.test_loader` with a copy of the trainer arguments.
        """
        # NOTE(review): "om"/"oo" presumably denote the one-to-many / one-to-one heads - confirm
        self.loss_names = ("box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo")
        validator = YOLOv10DetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )
        return validator

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Build a YOLOv10DetectionModel and optionally load weights into it.

        Args:
            cfg: Model configuration forwarded to YOLOv10DetectionModel.
            weights: Weights passed to `model.load()` when truthy.
            verbose (bool): Verbose model construction; only honoured when RANK is -1.

        Returns:
            (YOLOv10DetectionModel): The constructed model.
        """
        detector = YOLOv10DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            detector.load(weights)
        return detector
) - return self._predict_once(x) + return self._predict_once(x, *args, **kwargs) def _profile_one_layer(self, m, x, dt): """ @@ -260,7 +263,7 @@ class BaseModel(nn.Module): if verbose: LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights") - def loss(self, batch, preds=None): + def loss(self, batch, preds=None, return_images=False): """ Compute loss. @@ -271,8 +274,12 @@ class BaseModel(nn.Module): if not hasattr(self, "criterion"): self.criterion = self.init_criterion() - preds = self.forward(batch["img"]) if preds is None else preds - return self.criterion(preds, batch) + preds = self.forward(batch["img"], return_images=return_images) if preds is None else preds + if return_images: + preds, im = preds + loss = self.criterion(preds, batch) + out = loss if not return_images else (loss, im) + return out def init_criterion(self): """Initialize the loss criterion for the BaseModel."""