mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-28 09:44:22 +08:00
Update save_dir on new predict args (#3215)
This commit is contained in:
parent
431fad6834
commit
d69a1e8046
@ -250,6 +250,8 @@ class YOLO:
|
||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
if 'project' in overrides or 'name' in overrides:
|
||||
self.predictor.save_dir = self.predictor.get_save_dir()
|
||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||
|
||||
def track(self, source=None, stream=False, persist=False, **kwargs):
|
||||
|
@ -65,14 +65,13 @@ class BasePredictor:
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the predictor.
|
||||
save_dir (Path): Directory to save results.
|
||||
done_setup (bool): Whether the predictor has finished setup.
|
||||
done_warmup (bool): Whether the predictor has finished setup.
|
||||
model (nn.Module): Model used for prediction.
|
||||
data (dict): Data configuration.
|
||||
device (torch.device): Device used for prediction.
|
||||
dataset (Dataset): Dataset used for prediction.
|
||||
vid_path (str): Path to video file.
|
||||
vid_writer (cv2.VideoWriter): Video writer for saving video output.
|
||||
annotator (Annotator): Annotator used for prediction.
|
||||
data_path (str): Path to data.
|
||||
"""
|
||||
|
||||
@ -85,9 +84,7 @@ class BasePredictor:
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f'{self.args.mode}'
|
||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||
self.save_dir = self.get_save_dir()
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.25 # default conf=0.25
|
||||
self.done_warmup = False
|
||||
@ -108,6 +105,11 @@ class BasePredictor:
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def get_save_dir(self):
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f'{self.args.mode}'
|
||||
return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||
|
||||
def preprocess(self, im):
|
||||
"""Prepares input image before inference.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user