mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Hydra *.yml extension deprecated fix (#34)
This commit is contained in:
parent
6fe8bead35
commit
e4f7458d90
@ -24,7 +24,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT
|
|||||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||||
from ultralytics.yolo.utils.modeling import get_model
|
from ultralytics.yolo.utils.modeling import get_model
|
||||||
|
|
||||||
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yml"
|
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainer:
|
class BaseTrainer:
|
||||||
|
@ -47,7 +47,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
return torch.nn.functional.cross_entropy(preds, targets)
|
return torch.nn.functional.cross_entropy(preds, targets)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.stem)
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||||
def train(cfg):
|
def train(cfg):
|
||||||
cfg.model = cfg.model or "resnet18"
|
cfg.model = cfg.model or "resnet18"
|
||||||
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user