mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
Add CLI support for SAM, RTDETR (#3253)
This commit is contained in:
parent
4c2033d7c3
commit
58bccb1a9f
@ -17,6 +17,7 @@ class SAM:
|
||||
# Should raise AssertionError instead?
|
||||
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
|
||||
self.model = build_sam(model)
|
||||
self.task = 'segment' # required
|
||||
self.predictor = None # reuse predictor
|
||||
|
||||
def predict(self, source, stream=False, **kwargs):
|
||||
|
@ -366,9 +366,16 @@ def entrypoint(debug=''):
|
||||
if model is None:
|
||||
model = 'yolov8n.pt'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
overrides['model'] = model
|
||||
model = YOLO(model, task=task)
|
||||
if 'rtdetr' in model.lower(): # guess architecture
|
||||
from ultralytics import RTDETR
|
||||
model = RTDETR(model) # no task argument
|
||||
elif 'sam' in model.lower():
|
||||
from ultralytics import SAM
|
||||
model = SAM(model)
|
||||
else:
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(model, task=task)
|
||||
if isinstance(overrides.get('pretrained'), str):
|
||||
model.load(overrides['pretrained'])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user