mirror of https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:15:49 +08:00
ddp resume checkpoint fix (#184)
This commit is contained in:
parent 36efe34fe1
commit 2c36ab0f10
@@ -227,7 +227,7 @@ class BaseTrainer:
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
-            "nccl" if dist.is_nccl_available() else "gloo",
+            backend="nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
             world_size=world_size,
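
For context, a minimal standalone sketch of the DDP setup this hunk touches. It assumes a launcher such as torchrun has already exported MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE; the helper name setup_ddp is illustrative, not repository code.

    import os
    from datetime import timedelta

    import torch.distributed as dist

    def setup_ddp(rank: int, world_size: int) -> None:
        # NCCL_BLOCKING_WAIT makes NCCL collectives honor the timeout below
        os.environ["NCCL_BLOCKING_WAIT"] = "1"
        dist.init_process_group(
            backend="nccl" if dist.is_nccl_available() else "gloo",  # keyword argument, as in the fix
            timeout=timedelta(seconds=10800),  # 3 hours
            rank=rank,
            world_size=world_size,
        )
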
@@ -645,8 +645,8 @@ class BaseTrainer:
             resume = True
             self.args = get_cfg(ckpt_args)
-            self.args.model = str(last)  # reinstate model
-            for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
+            self.args.model = self.args.resume = str(last)  # reinstate model
+            for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
                 if k in overrides:
                     setattr(self.args, k, overrides[k])
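
The behavioral change in this hunk, as a hedged, runnable sketch: restore args from the checkpoint, reinstate the checkpoint path in both model and resume, then apply a whitelisted subset of user overrides, which now includes "device". All literals below are illustrative, not repository data.

    from types import SimpleNamespace

    ckpt_args = {"imgsz": 640, "batch": 16, "device": "0"}  # args saved in the checkpoint
    overrides = {"batch": 8, "device": "cpu"}  # e.g. resume after a CUDA OOM on other hardware

    args = SimpleNamespace(**ckpt_args)
    args.model = args.resume = "runs/train/weights/last.pt"  # reinstate model (and now resume) path
    for k in "imgsz", "batch", "device":  # "device" is newly allowed on resume
        if k in overrides:
            setattr(args, k, overrides[k])

    print(args.batch, args.device)  # -> 8 cpu
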
@@ -669,14 +669,11 @@ class BaseTrainer:
         if self.ema and ckpt.get("ema"):
             self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
             self.ema.updates = ckpt["updates"]
         if self.resume:
-            assert start_epoch > 0, (
-                f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
-                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
-            )
-            LOGGER.info(
-                f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
-            )
+            assert start_epoch > 0, (
+                f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+            )
+            LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
         if self.epochs < start_epoch:
             LOGGER.info(
                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
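
In practice, the two resume hunks together mean an interrupted run can be picked up on different hardware. A hedged usage example, following the 'yolo train model=...' form quoted in the assert message above (the paths and values are illustrative):

    # resume the interrupted run, moving it to GPU 0 with a smaller batch
    yolo train model=runs/detect/train/weights/last.pt resume=True device=0 batch=8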