This commit is contained in:
wa22 2024-05-24 05:40:00 +00:00
parent 742c7151b5
commit b714fa8bae
2 changed files with 2 additions and 1 deletions

View File

@ -14,6 +14,7 @@ batch: 16 # (int) number of images per batch (-1 for AutoBatch)
imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
save: True # (bool) save train checkpoints and predict results save: True # (bool) save train checkpoints and predict results
save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1) save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
val_period: 10 # (int) Validation every x epochs
cache: False # (bool) True/ram, disk or False. Use cache for data loading cache: False # (bool) True/ram, disk or False. Use cache for data loading
device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
workers: 8 # (int) number of worker threads for data loading (per RANK if DDP) workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)

View File

@ -425,7 +425,7 @@ class BaseTrainer:
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
# Validation # Validation
if (self.args.val and (((epoch+1) % 10 == 0) or (self.epochs - epoch) <= 10)) \ if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
or final_epoch or self.stopper.possible_stop or self.stop: or final_epoch or self.stopper.possible_stop or self.stop:
self.metrics, self.fitness = self.validate() self.metrics, self.fitness = self.validate()
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr}) self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})