mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
Fix some cuda training issues of segmentation (#46)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
db1031a1a9
commit
47f1cb3ef4
@ -142,7 +142,7 @@ class BaseTrainer:
|
|||||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
|
||||||
if rank in {0, -1}:
|
if rank in {0, -1}:
|
||||||
print(" Creating testloader rank :", rank)
|
print(" Creating testloader rank :", rank)
|
||||||
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
|
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1)
|
||||||
self.validator = self.get_validator()
|
self.validator = self.get_validator()
|
||||||
print("created testloader :", rank)
|
print("created testloader :", rank)
|
||||||
self.console.info(self.progress_string())
|
self.console.info(self.progress_string())
|
||||||
@ -150,6 +150,8 @@ class BaseTrainer:
|
|||||||
def _do_train(self, rank, world_size):
|
def _do_train(self, rank, world_size):
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
self._setup_ddp(rank, world_size)
|
self._setup_ddp(rank, world_size)
|
||||||
|
else:
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
|
||||||
# callback hook. before_train
|
# callback hook. before_train
|
||||||
self._setup_train(rank)
|
self._setup_train(rank)
|
||||||
@ -192,8 +194,8 @@ class BaseTrainer:
|
|||||||
losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0)
|
losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0)
|
||||||
if rank in {-1, 0}:
|
if rank in {-1, 0}:
|
||||||
pbar.set_description(
|
pbar.set_description(
|
||||||
(" {} " + "{:.3f} " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses,
|
(" {} " + "{:.3f} " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem,
|
||||||
batch["img"].shape[-1]))
|
*losses, batch["img"].shape[-1]))
|
||||||
|
|
||||||
if rank in [-1, 0]:
|
if rank in [-1, 0]:
|
||||||
# validation
|
# validation
|
||||||
@ -286,7 +288,8 @@ class BaseTrainer:
|
|||||||
"fitness" metric.
|
"fitness" metric.
|
||||||
"""
|
"""
|
||||||
self.metrics = self.validator(self)
|
self.metrics = self.validator(self)
|
||||||
self.fitness = self.metrics.get("fitness") or (-self.loss) # use loss as fitness measure if not found
|
self.fitness = self.metrics.get("fitness",
|
||||||
|
-self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||||
if not self.best_fitness or self.best_fitness < self.fitness:
|
if not self.best_fitness or self.best_fitness < self.fitness:
|
||||||
self.best_fitness = self.fitness
|
self.best_fitness = self.fitness
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
from ultralytics.yolo.utils.torch_utils import select_device
|
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
|
||||||
|
|
||||||
|
|
||||||
class BaseValidator:
|
class BaseValidator:
|
||||||
@ -36,7 +36,9 @@ class BaseValidator:
|
|||||||
if training:
|
if training:
|
||||||
model = trainer.model
|
model = trainer.model
|
||||||
self.args.half &= self.device.type != 'cpu'
|
self.args.half &= self.device.type != 'cpu'
|
||||||
model = model.half() if self.args.half else model
|
# NOTE: half() inference in evaluation will make training stuck,
|
||||||
|
# so I comment it out for now, I think we can reuse half mode after we add EMA.
|
||||||
|
# model = model.half() if self.args.half else model
|
||||||
else: # TODO: handle this when detectMultiBackend is supported
|
else: # TODO: handle this when detectMultiBackend is supported
|
||||||
# model = DetectMultiBacked(model)
|
# model = DetectMultiBacked(model)
|
||||||
pass
|
pass
|
||||||
@ -48,8 +50,8 @@ class BaseValidator:
|
|||||||
n_batches = len(self.dataloader)
|
n_batches = len(self.dataloader)
|
||||||
desc = self.get_desc()
|
desc = self.get_desc()
|
||||||
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||||
self.init_metrics(model)
|
self.init_metrics(de_parallel(model))
|
||||||
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
|
with torch.no_grad():
|
||||||
for batch_i, batch in enumerate(bar):
|
for batch_i, batch in enumerate(bar):
|
||||||
self.batch_i = batch_i
|
self.batch_i = batch_i
|
||||||
# pre-process
|
# pre-process
|
||||||
@ -58,7 +60,7 @@ class BaseValidator:
|
|||||||
|
|
||||||
# inference
|
# inference
|
||||||
with dt[1]:
|
with dt[1]:
|
||||||
preds = model(batch["img"])
|
preds = model(batch["img"].float())
|
||||||
# TODO: remember to add native augmentation support when implementing model, like:
|
# TODO: remember to add native augmentation support when implementing model, like:
|
||||||
# preds, train_out = model(im, augment=augment)
|
# preds, train_out = model(im, augment=augment)
|
||||||
|
|
||||||
@ -85,6 +87,8 @@ class BaseValidator:
|
|||||||
self.logger.info(
|
self.logger.info(
|
||||||
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
|
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
model.float()
|
||||||
# TODO: implement save json
|
# TODO: implement save json
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
@ -6,10 +6,11 @@ from ultralytics.yolo.engine.validator import BaseValidator
|
|||||||
class ClassificationValidator(BaseValidator):
|
class ClassificationValidator(BaseValidator):
|
||||||
|
|
||||||
def init_metrics(self, model):
|
def init_metrics(self, model):
|
||||||
self.correct = torch.tensor([])
|
self.correct = torch.tensor([], device=next(model.parameters()).device)
|
||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device)
|
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
||||||
|
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
|
||||||
batch["cls"] = batch["cls"].to(self.device)
|
batch["cls"] = batch["cls"].to(self.device)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ class SegmentationTrainer(BaseTrainer):
|
|||||||
def get_dataloader(self, dataset_path, batch_size, rank=0):
|
def get_dataloader(self, dataset_path, batch_size, rank=0):
|
||||||
# TODO: manage splits differently
|
# TODO: manage splits differently
|
||||||
# calculate stride - check if model is initialized
|
# calculate stride - check if model is initialized
|
||||||
gs = max(int(self.model.stride.max() if self.model else 0), 32)
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||||
loader = build_dataloader(
|
loader = build_dataloader(
|
||||||
img_path=dataset_path,
|
img_path=dataset_path,
|
||||||
img_size=self.args.img_size,
|
img_size=self.args.img_size,
|
||||||
@ -220,7 +220,7 @@ class SegmentationTrainer(BaseTrainer):
|
|||||||
mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device))
|
mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device))
|
||||||
for bi in b.unique():
|
for bi in b.unique():
|
||||||
j = b == bi # matching index
|
j = b == bi # matching index
|
||||||
if True:
|
if self.args.overlap_mask:
|
||||||
mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0)
|
mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0)
|
||||||
else:
|
else:
|
||||||
mask_gti = masks[tidxs[i]][j]
|
mask_gti = masks[tidxs[i]][j]
|
||||||
|
@ -30,11 +30,13 @@ class SegmentationValidator(BaseValidator):
|
|||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
||||||
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225
|
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
||||||
batch["bboxes"] = batch["bboxes"].to(self.device)
|
|
||||||
batch["masks"] = batch["masks"].to(self.device).float()
|
batch["masks"] = batch["masks"].to(self.device).float()
|
||||||
self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width
|
self.nb, _, self.height, self.width = batch["img"].shape # batch size, channels, height, width
|
||||||
self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
|
||||||
|
self.targets = self.targets.to(self.device)
|
||||||
|
height, width = batch["img"].shape[2:]
|
||||||
|
self.targets[:, 2:] *= torch.tensor((width, height, width, height), device=self.device) # to pixels
|
||||||
self.lb = [self.targets[self.targets[:, 0] == i, 1:]
|
self.lb = [self.targets[self.targets[:, 0] == i, 1:]
|
||||||
for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling
|
for i in range(self.nb)] if self.args.save_hybrid else [] # for autolabelling
|
||||||
|
|
||||||
@ -75,7 +77,7 @@ class SegmentationValidator(BaseValidator):
|
|||||||
agnostic=self.args.single_cls,
|
agnostic=self.args.single_cls,
|
||||||
max_det=self.args.max_det,
|
max_det=self.args.max_det,
|
||||||
nm=self.nm)
|
nm=self.nm)
|
||||||
return (p, preds[0], preds[2])
|
return (p, preds[1], preds[2])
|
||||||
|
|
||||||
def update_metrics(self, preds, batch):
|
def update_metrics(self, preds, batch):
|
||||||
# Metrics
|
# Metrics
|
||||||
@ -83,7 +85,7 @@ class SegmentationValidator(BaseValidator):
|
|||||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||||
labels = self.targets[self.targets[:, 0] == si, 1:]
|
labels = self.targets[self.targets[:, 0] == si, 1:]
|
||||||
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
||||||
shape = Path(batch["im_file"][si])
|
shape = batch["shape"][si]
|
||||||
# path = batch["shape"][si][0]
|
# path = batch["shape"][si][0]
|
||||||
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||||
@ -106,22 +108,29 @@ class SegmentationValidator(BaseValidator):
|
|||||||
if self.args.single_cls:
|
if self.args.single_cls:
|
||||||
pred[:, 5] = 0
|
pred[:, 5] = 0
|
||||||
predn = pred.clone()
|
predn = pred.clone()
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape, batch["shape"][si][1]) # native-space pred
|
ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape) # native-space pred
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
if nl:
|
if nl:
|
||||||
tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes
|
tbox = ops.xywh2xyxy(labels[:, 1:5]) # target boxes
|
||||||
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape, batch["shapes"][si][1]) # native-space labels
|
ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape) # native-space labels
|
||||||
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
|
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
|
||||||
correct_bboxes = self._process_batch(predn, labelsn, self.iouv)
|
correct_bboxes = self._process_batch(predn, labelsn, self.iouv)
|
||||||
correct_masks = self._process_batch(predn, labelsn, self.iouv, pred_masks, gt_masks, masks=True)
|
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||||
|
correct_masks = self._process_batch(predn,
|
||||||
|
labelsn,
|
||||||
|
self.iouv,
|
||||||
|
pred_masks,
|
||||||
|
gt_masks,
|
||||||
|
overlap=self.args.overlap_mask,
|
||||||
|
masks=True)
|
||||||
if self.args.plots:
|
if self.args.plots:
|
||||||
self.confusion_matrix.process_batch(predn, labelsn)
|
self.confusion_matrix.process_batch(predn, labelsn)
|
||||||
self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:,
|
self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:,
|
||||||
0])) # (conf, pcls, tcls)
|
0])) # (conf, pcls, tcls)
|
||||||
|
|
||||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||||
if self.plots and self.batch_i < 3:
|
if self.args.plots and self.batch_i < 3:
|
||||||
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||||
|
|
||||||
# TODO: Save/log
|
# TODO: Save/log
|
||||||
|
Loading…
x
Reference in New Issue
Block a user