mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 14:44:21 +08:00
Add CLI support for SAM, RTDETR (#3253)
This commit is contained in:
parent
4c2033d7c3
commit
58bccb1a9f
@ -17,6 +17,7 @@ class SAM:
|
|||||||
# Should raise AssertionError instead?
|
# Should raise AssertionError instead?
|
||||||
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
|
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
|
||||||
self.model = build_sam(model)
|
self.model = build_sam(model)
|
||||||
|
self.task = 'segment' # required
|
||||||
self.predictor = None # reuse predictor
|
self.predictor = None # reuse predictor
|
||||||
|
|
||||||
def predict(self, source, stream=False, **kwargs):
|
def predict(self, source, stream=False, **kwargs):
|
||||||
|
@ -366,9 +366,16 @@ def entrypoint(debug=''):
|
|||||||
if model is None:
|
if model is None:
|
||||||
model = 'yolov8n.pt'
|
model = 'yolov8n.pt'
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
|
||||||
overrides['model'] = model
|
overrides['model'] = model
|
||||||
model = YOLO(model, task=task)
|
if 'rtdetr' in model.lower(): # guess architecture
|
||||||
|
from ultralytics import RTDETR
|
||||||
|
model = RTDETR(model) # no task argument
|
||||||
|
elif 'sam' in model.lower():
|
||||||
|
from ultralytics import SAM
|
||||||
|
model = SAM(model)
|
||||||
|
else:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
model = YOLO(model, task=task)
|
||||||
if isinstance(overrides.get('pretrained'), str):
|
if isinstance(overrides.get('pretrained'), str):
|
||||||
model.load(overrides['pretrained'])
|
model.load(overrides['pretrained'])
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user