mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 14:44:21 +08:00
Add world_size check before setting up DDP train (#3191)
This commit is contained in:
parent
f8e1dcc43f
commit
0d91d6df6e
@ -217,7 +217,7 @@ class BaseTrainer:
|
|||||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||||
if RANK > -1: # DDP
|
if RANK > -1 and world_size > 1: # DDP
|
||||||
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
||||||
self.amp = bool(self.amp) # as boolean
|
self.amp = bool(self.amp) # as boolean
|
||||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user