diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py index b8def485..d1cb18df 100644 --- a/ultralytics/models/rtdetr/model.py +++ b/ultralytics/models/rtdetr/model.py @@ -6,6 +6,7 @@ hybrid encoder and IoU-aware query selection for enhanced detection accuracy. For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf """ +from pathlib import Path from ultralytics.engine.model import Model from ultralytics.nn.tasks import RTDETRDetectionModel @@ -34,7 +35,7 @@ class RTDETR(Model): Raises: NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. """ - if model and model.split(".")[-1] not in ("pt", "yaml", "yml"): + if model and Path(model).suffix not in (".pt", ".yaml", ".yml"): raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.") super().__init__(model=model, task="detect")