mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-23 17:15:39 +08:00 
			
		
		
		
	Fix some cuda training issues of segmentation (#46)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									db1031a1a9
								
							
						
					
					
						commit
						47f1cb3ef4
					
				| @ -142,7 +142,7 @@ class BaseTrainer: | |||||||
|         self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank) |         self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank) | ||||||
|         if rank in {0, -1}: |         if rank in {0, -1}: | ||||||
|             print(" Creating testloader rank :", rank) |             print(" Creating testloader rank :", rank) | ||||||
|             self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank) |             self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1) | ||||||
|             self.validator = self.get_validator() |             self.validator = self.get_validator() | ||||||
|             print("created testloader :", rank) |             print("created testloader :", rank) | ||||||
|             self.console.info(self.progress_string()) |             self.console.info(self.progress_string()) | ||||||
| @ -150,6 +150,8 @@ class BaseTrainer: | |||||||
|     def _do_train(self, rank, world_size): |     def _do_train(self, rank, world_size): | ||||||
|         if world_size > 1: |         if world_size > 1: | ||||||
|             self._setup_ddp(rank, world_size) |             self._setup_ddp(rank, world_size) | ||||||
|  |         else: | ||||||
|  |             self.model = self.model.to(self.device) | ||||||
| 
 | 
 | ||||||
|         # callback hook. before_train |         # callback hook. before_train | ||||||
|         self._setup_train(rank) |         self._setup_train(rank) | ||||||
| @ -192,8 +194,8 @@ class BaseTrainer: | |||||||
|                 losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0) |                 losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0) | ||||||
|                 if rank in {-1, 0}: |                 if rank in {-1, 0}: | ||||||
|                     pbar.set_description( |                     pbar.set_description( | ||||||
|                         (" {} " + "{:.3f}  " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses, |                         (" {} " + "{:.3f}  " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem, | ||||||
|                                                                       batch["img"].shape[-1])) |                                                                                *losses, batch["img"].shape[-1])) | ||||||
| 
 | 
 | ||||||
|             if rank in [-1, 0]: |             if rank in [-1, 0]: | ||||||
|                 # validation |                 # validation | ||||||
| @ -286,7 +288,8 @@ class BaseTrainer: | |||||||
|         "fitness" metric. |         "fitness" metric. | ||||||
|         """ |         """ | ||||||
|         self.metrics = self.validator(self) |         self.metrics = self.validator(self) | ||||||
|         self.fitness = self.metrics.get("fitness") or (-self.loss)  # use loss as fitness measure if not found |         self.fitness = self.metrics.get("fitness", | ||||||
|  |                                         -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found | ||||||
|         if not self.best_fitness or self.best_fitness < self.fitness: |         if not self.best_fitness or self.best_fitness < self.fitness: | ||||||
|             self.best_fitness = self.fitness |             self.best_fitness = self.fitness | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ from tqdm import tqdm | |||||||
| 
 | 
 | ||||||
| from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG | from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG | ||||||
| from ultralytics.yolo.utils.ops import Profile | from ultralytics.yolo.utils.ops import Profile | ||||||
| from ultralytics.yolo.utils.torch_utils import select_device | from ultralytics.yolo.utils.torch_utils import de_parallel, select_device | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class BaseValidator: | class BaseValidator: | ||||||
| @ -36,7 +36,9 @@ class BaseValidator: | |||||||
|         if training: |         if training: | ||||||
|             model = trainer.model |             model = trainer.model | ||||||
|             self.args.half &= self.device.type != 'cpu' |             self.args.half &= self.device.type != 'cpu' | ||||||
|             model = model.half() if self.args.half else model |             # NOTE: half() inference in evaluation will make training stuck, | ||||||
|  |             # so I comment it out for now, I think we can reuse half mode after we add EMA. | ||||||
|  |             # model = model.half() if self.args.half else model | ||||||
|         else:  # TODO: handle this when detectMultiBackend is supported |         else:  # TODO: handle this when detectMultiBackend is supported | ||||||
|             # model = DetectMultiBacked(model) |             # model = DetectMultiBacked(model) | ||||||
|             pass |             pass | ||||||
| @ -48,8 +50,8 @@ class BaseValidator: | |||||||
|         n_batches = len(self.dataloader) |         n_batches = len(self.dataloader) | ||||||
|         desc = self.get_desc() |         desc = self.get_desc() | ||||||
|         bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') |         bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') | ||||||
|         self.init_metrics(model) |         self.init_metrics(de_parallel(model)) | ||||||
|         with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'): |         with torch.no_grad(): | ||||||
|             for batch_i, batch in enumerate(bar): |             for batch_i, batch in enumerate(bar): | ||||||
|                 self.batch_i = batch_i |                 self.batch_i = batch_i | ||||||
|                 # pre-process |                 # pre-process | ||||||
| @ -58,7 +60,7 @@ class BaseValidator: | |||||||
| 
 | 
 | ||||||
|                 # inference |                 # inference | ||||||
|                 with dt[1]: |                 with dt[1]: | ||||||
|                     preds = model(batch["img"]) |                     preds = model(batch["img"].float()) | ||||||
|                     # TODO: remember to add native augmentation support when implementing model, like: |                     # TODO: remember to add native augmentation support when implementing model, like: | ||||||
|                     #  preds, train_out = model(im, augment=augment) |                     #  preds, train_out = model(im, augment=augment) | ||||||
| 
 | 
 | ||||||
| @ -85,6 +87,8 @@ class BaseValidator: | |||||||
|             self.logger.info( |             self.logger.info( | ||||||
|                 'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t) |                 'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t) | ||||||
| 
 | 
 | ||||||
|  |         if self.training: | ||||||
|  |             model.float() | ||||||
|         # TODO: implement save json |         # TODO: implement save json | ||||||
| 
 | 
 | ||||||
|         return stats |         return stats | ||||||
|  | |||||||
| @ -6,10 +6,11 @@ from ultralytics.yolo.engine.validator import BaseValidator | |||||||
| class ClassificationValidator(BaseValidator): | class ClassificationValidator(BaseValidator): | ||||||
| 
 | 
 | ||||||
|     def init_metrics(self, model): |     def init_metrics(self, model): | ||||||
|         self.correct = torch.tensor([]) |         self.correct = torch.tensor([], device=next(model.parameters()).device) | ||||||
| 
 | 
 | ||||||
|     def preprocess(self, batch): |     def preprocess(self, batch): | ||||||
|         batch["img"] = batch["img"].to(self.device) |         batch["img"] = batch["img"].to(self.device, non_blocking=True) | ||||||
|  |         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float() | ||||||
|         batch["cls"] = batch["cls"].to(self.device) |         batch["cls"] = batch["cls"].to(self.device) | ||||||
|         return batch |         return batch | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -23,7 +23,7 @@ class SegmentationTrainer(BaseTrainer): | |||||||
|     def get_dataloader(self, dataset_path, batch_size, rank=0): |     def get_dataloader(self, dataset_path, batch_size, rank=0): | ||||||
|         # TODO: manage splits differently |         # TODO: manage splits differently | ||||||
|         # calculate stride - check if model is initialized |         # calculate stride - check if model is initialized | ||||||
|         gs = max(int(self.model.stride.max() if self.model else 0), 32) |         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) | ||||||
|         loader = build_dataloader( |         loader = build_dataloader( | ||||||
|             img_path=dataset_path, |             img_path=dataset_path, | ||||||
|             img_size=self.args.img_size, |             img_size=self.args.img_size, | ||||||
| @ -220,7 +220,7 @@ class SegmentationTrainer(BaseTrainer): | |||||||
|                 mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)) |                 mxyxy = xywh2xyxy(xywhn[i] * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)) | ||||||
|                 for bi in b.unique(): |                 for bi in b.unique(): | ||||||
|                     j = b == bi  # matching index |                     j = b == bi  # matching index | ||||||
|                     if True: |                     if self.args.overlap_mask: | ||||||
|                         mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0) |                         mask_gti = torch.where(masks[bi][None] == tidxs[i][j].view(-1, 1, 1), 1.0, 0.0) | ||||||
|                     else: |                     else: | ||||||
|                         mask_gti = masks[tidxs[i]][j] |                         mask_gti = masks[tidxs[i]][j] | ||||||
|  | |||||||
| @ -30,11 +30,13 @@ class SegmentationValidator(BaseValidator): | |||||||
| 
 | 
 | ||||||
|     def preprocess(self, batch): |     def preprocess(self, batch): | ||||||
|         batch["img"] = batch["img"].to(self.device, non_blocking=True) |         batch["img"] = batch["img"].to(self.device, non_blocking=True) | ||||||
|         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 225 |         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 | ||||||
|         batch["bboxes"] = batch["bboxes"].to(self.device) |  | ||||||
|         batch["masks"] = batch["masks"].to(self.device).float() |         batch["masks"] = batch["masks"].to(self.device).float() | ||||||
|         self.nb, _, self.height, self.width = batch["img"].shape  # batch size, channels, height, width |         self.nb, _, self.height, self.width = batch["img"].shape  # batch size, channels, height, width | ||||||
|         self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) |         self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) | ||||||
|  |         self.targets = self.targets.to(self.device) | ||||||
|  |         height, width = batch["img"].shape[2:] | ||||||
|  |         self.targets[:, 2:] *= torch.tensor((width, height, width, height), device=self.device)  # to pixels | ||||||
|         self.lb = [self.targets[self.targets[:, 0] == i, 1:] |         self.lb = [self.targets[self.targets[:, 0] == i, 1:] | ||||||
|                    for i in range(self.nb)] if self.args.save_hybrid else []  # for autolabelling |                    for i in range(self.nb)] if self.args.save_hybrid else []  # for autolabelling | ||||||
| 
 | 
 | ||||||
| @ -75,7 +77,7 @@ class SegmentationValidator(BaseValidator): | |||||||
|                                     agnostic=self.args.single_cls, |                                     agnostic=self.args.single_cls, | ||||||
|                                     max_det=self.args.max_det, |                                     max_det=self.args.max_det, | ||||||
|                                     nm=self.nm) |                                     nm=self.nm) | ||||||
|         return (p, preds[0], preds[2]) |         return (p, preds[1], preds[2]) | ||||||
| 
 | 
 | ||||||
|     def update_metrics(self, preds, batch): |     def update_metrics(self, preds, batch): | ||||||
|         # Metrics |         # Metrics | ||||||
| @ -83,7 +85,7 @@ class SegmentationValidator(BaseValidator): | |||||||
|         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): |         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): | ||||||
|             labels = self.targets[self.targets[:, 0] == si, 1:] |             labels = self.targets[self.targets[:, 0] == si, 1:] | ||||||
|             nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions |             nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions | ||||||
|             shape = Path(batch["im_file"][si]) |             shape = batch["shape"][si] | ||||||
|             # path = batch["shape"][si][0] |             # path = batch["shape"][si][0] | ||||||
|             correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init |             correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init | ||||||
|             correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init |             correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init | ||||||
| @ -106,22 +108,29 @@ class SegmentationValidator(BaseValidator): | |||||||
|             if self.args.single_cls: |             if self.args.single_cls: | ||||||
|                 pred[:, 5] = 0 |                 pred[:, 5] = 0 | ||||||
|             predn = pred.clone() |             predn = pred.clone() | ||||||
|             ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape, batch["shape"][si][1])  # native-space pred |             ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape)  # native-space pred | ||||||
| 
 | 
 | ||||||
|             # Evaluate |             # Evaluate | ||||||
|             if nl: |             if nl: | ||||||
|                 tbox = ops.xywh2xyxy(labels[:, 1:5])  # target boxes |                 tbox = ops.xywh2xyxy(labels[:, 1:5])  # target boxes | ||||||
|                 ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape, batch["shapes"][si][1])  # native-space labels |                 ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape)  # native-space labels | ||||||
|                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels |                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels | ||||||
|                 correct_bboxes = self._process_batch(predn, labelsn, self.iouv) |                 correct_bboxes = self._process_batch(predn, labelsn, self.iouv) | ||||||
|                 correct_masks = self._process_batch(predn, labelsn, self.iouv, pred_masks, gt_masks, masks=True) |                 # TODO: maybe remove these `self.` arguments as they already are member variable | ||||||
|  |                 correct_masks = self._process_batch(predn, | ||||||
|  |                                                     labelsn, | ||||||
|  |                                                     self.iouv, | ||||||
|  |                                                     pred_masks, | ||||||
|  |                                                     gt_masks, | ||||||
|  |                                                     overlap=self.args.overlap_mask, | ||||||
|  |                                                     masks=True) | ||||||
|                 if self.args.plots: |                 if self.args.plots: | ||||||
|                     self.confusion_matrix.process_batch(predn, labelsn) |                     self.confusion_matrix.process_batch(predn, labelsn) | ||||||
|             self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, |             self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, | ||||||
|                                                                                              0]))  # (conf, pcls, tcls) |                                                                                              0]))  # (conf, pcls, tcls) | ||||||
| 
 | 
 | ||||||
|             pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) |             pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) | ||||||
|             if self.plots and self.batch_i < 3: |             if self.args.plots and self.batch_i < 3: | ||||||
|                 plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot |                 plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot | ||||||
| 
 | 
 | ||||||
|             # TODO: Save/log |             # TODO: Save/log | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Laughing
						Laughing