Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-06-13 12:14:22 +08:00

Cleanup argument handling in Model class (#4614)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sebastian Stapf <42514241+Wiqzard@users.noreply.github.com>

parent 53b4f8c713
commit 7e99804263

Changed paths: docs/reference/cfg, tests, ultralytics
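The common thread across the Model changes below is replacing per-method `overrides.copy()` / `overrides.update(kwargs)` bookkeeping with a single dict merge in which later entries win. A minimal sketch of that precedence, with made-up values (not taken from the commit):

```python
# Rightmost-wins merge, as used throughout the refactored Model methods (illustrative values only).
stored_overrides = {'imgsz': 640, 'conf': 0.4}   # persisted on the Model instance
method_defaults = {'conf': 0.25}                 # the 'custom' dict each method defines
user_kwargs = {'conf': 0.5}                      # whatever the caller passes

args = {**stored_overrides, **method_defaults, **user_kwargs, 'mode': 'predict'}
assert args == {'imgsz': 640, 'conf': 0.5, 'mode': 'predict'}
```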
docs/reference/cfg

@@ -17,6 +17,10 @@ keywords: Ultralytics, YOLO, Configuration, cfg2dict, handle_deprecation, merge_
 ## ::: ultralytics.cfg.get_cfg
 <br><br>
 
+---
+## ::: ultralytics.cfg.get_save_dir
+<br><br>
+
 ---
 ## ::: ultralytics.cfg._handle_deprecation
 <br><br>
tests

@@ -182,9 +182,11 @@ def test_export_openvino():
 
 def test_export_coreml():
     if not WINDOWS:  # RuntimeError: BlobWriter not loaded with coremltools 7.0 on windows
-        f = YOLO(MODEL).export(format='coreml', nms=True)
         if MACOS:
-            YOLO(f)(SOURCE)  # model prediction only supported on macOS
+            f = YOLO(MODEL).export(format='coreml')
+            YOLO(f)(SOURCE)  # model prediction only supported on macOS for nms=False models
+        else:
+            YOLO(MODEL).export(format='coreml', nms=True)
 
 
 def test_export_tflite(enabled=False):
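For context, a usage sketch of the export paths the updated test now exercises; the checkpoint name and image path are placeholders, not part of the commit:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')               # placeholder checkpoint
f = model.export(format='coreml')        # nms=False export; per the test, inference on it works on macOS
YOLO(f)('bus.jpg')                       # placeholder image; prediction on the exported model (macOS only)
model.export(format='coreml', nms=True)  # pipelined export only; see the AutoBackend change further below
```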
ultralytics

@@ -5,12 +5,10 @@ import sys
 from pathlib import Path
 from typing import Union
 
-from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.engine.exporter import Exporter
+from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
 from ultralytics.hub.utils import HUB_WEB_ROOT
 from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
-from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, emojis,
-                               yaml_load)
+from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, emojis, yaml_load
 from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
 from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
 from ultralytics.utils.torch_utils import smart_inference_mode
@@ -118,9 +116,9 @@ class Model:
         cfg_dict = yaml_model_load(cfg)
         self.cfg = cfg
         self.task = task or guess_model_task(cfg_dict)
-        model = model or self.smart_load('model')
-        self.model = model(cfg_dict, verbose=verbose and RANK == -1)  # build model
+        self.model = (model or self.smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1)  # build model
         self.overrides['model'] = self.cfg
+        self.overrides['task'] = self.task
 
         # Below added to allow export from YAMLs
         args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine model and default args, preferring model args
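The `(model or self.smart_load('model'))(cfg_dict, ...)` line resolves a class and instantiates it in one expression. A stand-alone sketch of that resolve-then-call pattern; the names here are stand-ins, not the library's:

```python
class DetectionModel:                      # stand-in for a real task model class
    def __init__(self, cfg, verbose=False):
        self.cfg, self.verbose = cfg, verbose

TASK_MAP = {'detect': DetectionModel}      # hypothetical stand-in for Model.task_map / smart_load

def build_model(cfg, task='detect', model_cls=None, verbose=True):
    # an explicit class wins, otherwise one is looked up for the task; the winner is called immediately
    return (model_cls or TASK_MAP[task])(cfg, verbose=verbose)

m = build_model({'nc': 80})
assert isinstance(m, DetectionModel)
```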
@@ -220,28 +218,22 @@ class Model:
         if source is None:
             source = ASSETS
             LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
 
         is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
             x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
-        # Check prompts for SAM/FastSAM
-        prompts = kwargs.pop('prompts', None)
-        overrides = self.overrides.copy()
-        overrides['conf'] = 0.25
-        overrides.update(kwargs)  # prefer kwargs
-        overrides['mode'] = kwargs.get('mode', 'predict')
-        assert overrides['mode'] in ['track', 'predict']
-        if not is_cli:
-            overrides['save'] = kwargs.get('save', False)  # do not save by default if called in Python
+
+        custom = {'conf': 0.25, 'save': is_cli}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, 'mode': 'predict'}  # highest priority args on the right
+        prompts = args.pop('prompts', None)  # for SAM-type models
+
         if not self.predictor:
-            self.task = overrides.get('task') or self.task
-            predictor = predictor or self.smart_load('predictor')
-            self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
+            self.predictor = (predictor or self.smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
             self.predictor.setup_model(model=self.model, verbose=is_cli)
         else:  # only update args if predictor is already setup
-            self.predictor.args = get_cfg(self.predictor.args, overrides)
-            if 'project' in overrides or 'name' in overrides:
+            self.predictor.args = get_cfg(self.predictor.args, args)
+            if 'project' in args or 'name' in args:
                 self.predictor.save_dir = get_save_dir(self.predictor.args)
-        # Set prompts for SAM/FastSAM
-        if len and hasattr(self.predictor, 'set_prompts'):
+
+        if prompts and hasattr(self.predictor, 'set_prompts'):  # for SAM-type models
            self.predictor.set_prompts(prompts)
         return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
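Two behavioural details of the new predict() merge are worth noting: 'save' now defaults to whether the call came from the CLI, and 'prompts' is popped so it never reaches the predictor's config. A small sketch with a hypothetical helper mirroring that merge (not a library function):

```python
def build_predict_args(overrides, is_cli, **kwargs):
    custom = {'conf': 0.25, 'save': is_cli}                    # method defaults
    args = {**overrides, **custom, **kwargs, 'mode': 'predict'}
    prompts = args.pop('prompts', None)                        # stripped out for SAM-type models
    return args, prompts

args, prompts = build_predict_args({}, is_cli=False, save=True, prompts=[[100, 100]])
assert args['save'] is True and 'prompts' not in args and prompts == [[100, 100]]
assert build_predict_args({}, is_cli=True)[0]['save'] is True    # CLI saves by default
assert build_predict_args({}, is_cli=False)[0]['save'] is False  # Python API does not
```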
@@ -257,46 +249,31 @@ class Model:
 
         Returns:
             (List[ultralytics.engine.results.Results]): The tracking results.
 
         """
         if not hasattr(self.predictor, 'trackers'):
             from ultralytics.trackers import register_tracker
             register_tracker(self, persist)
         # ByteTrack-based method needs low confidence predictions as input
-        conf = kwargs.get('conf') or 0.1
-        kwargs['conf'] = conf
+        kwargs['conf'] = kwargs.get('conf') or 0.1
         kwargs['mode'] = 'track'
         return self.predict(source=source, stream=stream, **kwargs)
 
     @smart_inference_mode()
-    def val(self, data=None, validator=None, **kwargs):
+    def val(self, validator=None, **kwargs):
         """
         Validate a model on a given dataset.
 
         Args:
-            data (str): The dataset to validate on. Accepts all formats accepted by yolo
             validator (BaseValidator): Customized validator.
             **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
         """
-        overrides = self.overrides.copy()
-        overrides['rect'] = True  # rect batches as default
-        overrides.update(kwargs)
-        overrides['mode'] = 'val'
-        if overrides.get('imgsz') is None:
-            overrides['imgsz'] = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
-        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
-        args.data = data or args.data
-        if 'task' in overrides:
-            self.task = args.task
-        else:
-            args.task = self.task
-        validator = validator or self.smart_load('validator')
-        args.imgsz = check_imgsz(args.imgsz, max_dim=1)
+        custom = {'rect': True}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, 'mode': 'val'}  # highest priority args on the right
+        args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1)
 
-        validator = validator(args=args, _callbacks=self.callbacks)
+        validator = (validator or self.smart_load('validator'))(args=args, _callbacks=self.callbacks)
         validator(model=self.model)
         self.metrics = validator.metrics
 
         return validator.metrics
 
     @smart_inference_mode()
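After the cleanup, val() no longer takes 'data' as a named parameter; it simply flows through **kwargs into the merged args, with 'rect': True as the method default. A hedged usage sketch; the checkpoint and dataset names are placeholders:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')                           # placeholder checkpoint
metrics = model.val(data='coco128.yaml', imgsz=640)  # 'data' passed as a plain kwarg
```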
@@ -309,17 +286,16 @@ class Model:
         """
         self._check_is_pytorch_model()
         from ultralytics.utils.benchmarks import benchmark
-        overrides = self.model.args.copy()
-        overrides.update(kwargs)
-        overrides['mode'] = 'benchmark'
-        overrides = {**DEFAULT_CFG_DICT, **overrides}  # fill in missing overrides keys with defaults
+
+        custom = {'verbose': False}  # method defaults
+        args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, 'mode': 'benchmark'}
         return benchmark(
             model=self,
             data=kwargs.get('data'),  # if no 'data' argument passed set data=None for default datasets
-            imgsz=overrides['imgsz'],
-            half=overrides['half'],
-            int8=overrides['int8'],
-            device=overrides['device'],
+            imgsz=args['imgsz'],
+            half=args['half'],
+            int8=args['int8'],
+            device=args['device'],
             verbose=kwargs.get('verbose'))
 
     def export(self, **kwargs):
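benchmark() is the one method that still seeds the merge with DEFAULT_CFG_DICT, so the full precedence is defaults < trained model args < method defaults < user kwargs. A sketch with trimmed stand-in dicts (the values are illustrative only):

```python
DEFAULT_CFG_DICT = {'imgsz': 640, 'half': False, 'int8': False, 'device': '', 'verbose': True}  # trimmed stand-in
model_args = {'imgsz': 1280}   # hypothetical args restored from the checkpoint
custom = {'verbose': False}    # method default
kwargs = {'half': True}        # user call-site args

args = {**DEFAULT_CFG_DICT, **model_args, **custom, **kwargs, 'mode': 'benchmark'}
assert (args['imgsz'], args['half'], args['verbose']) == (1280, True, False)
```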
@@ -327,22 +303,13 @@ class Model:
         Export model.
 
         Args:
-            **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
+            **kwargs : Any other args accepted by the Exporter. To see all args check 'configuration' section in docs.
         """
         self._check_is_pytorch_model()
-        overrides = self.overrides.copy()
-        overrides.update(kwargs)
-        overrides['mode'] = 'export'
-        if overrides.get('imgsz') is None:
-            overrides['imgsz'] = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
-        if 'batch' not in kwargs:
-            overrides['batch'] = 1  # default to 1 if not modified
-        if 'data' not in kwargs:
-            overrides['data'] = None  # default to None if not modified (avoid int8 calibration with coco.yaml)
-        if 'verbose' not in kwargs:
-            overrides['verbose'] = False
-        args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
-        args.task = self.task
+        from .exporter import Exporter
+
+        custom = {'imgsz': self.model.args['imgsz'], 'batch': 1, 'data': None, 'verbose': False}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, 'mode': 'export'}  # highest priority args on the right
         return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
 
     def train(self, trainer=None, **kwargs):
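export() now imports Exporter locally (a common way to defer the dependency until export is actually requested) and folds the old if-chain of defaults into one 'custom' dict. A usage sketch; the checkpoint, format, and size values are placeholders:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')              # placeholder checkpoint
model.export(format='onnx')             # imgsz defaults to the trained size, batch=1, data=None
model.export(format='onnx', imgsz=320)  # any explicit kwarg still overrides those defaults
```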
@@ -359,20 +326,15 @@ class Model:
             LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
             kwargs = self.session.train_args
         check_pip_update_available()
-        overrides = self.overrides.copy()
-        if kwargs.get('cfg'):
-            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
-            overrides = yaml_load(check_yaml(kwargs['cfg']))
-        overrides.update(kwargs)
-        overrides['mode'] = 'train'
-        if not overrides.get('data'):
-            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
-        if overrides.get('resume'):
-            overrides['resume'] = self.ckpt_path
-        self.task = overrides.get('task') or self.task
-        trainer = trainer or self.smart_load('trainer')
-        self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
-        if not overrides.get('resume'):  # manually set model only if not resuming
+
+        overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
+        custom = {'data': TASK2DATA[self.task]}  # method defaults
+        args = {**overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right
+        if args.get('resume'):
+            args['resume'] = self.ckpt_path
+
+        self.trainer = (trainer or self.smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
+        if not args.get('resume'):  # manually set model only if not resuming
             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
             self.model = self.trainer.model
         self.trainer.hub_session = self.session  # attach optional HUB session
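train() stops raising when 'data' is missing and instead falls back to a per-task default via TASK2DATA. A sketch of that fallback with a hypothetical mapping (the real mapping lives in ultralytics.cfg) and a helper that mirrors, but is not, the library code:

```python
TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml'}  # hypothetical subset

def resolve_train_args(overrides, task, **kwargs):
    custom = {'data': TASK2DATA[task]}  # method default dataset for the task
    return {**overrides, **custom, **kwargs, 'mode': 'train'}

assert resolve_train_args({}, 'detect')['data'] == 'coco8.yaml'               # fallback used
assert resolve_train_args({}, 'detect', data='my.yaml')['data'] == 'my.yaml'  # explicit data wins
```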
@@ -455,7 +417,7 @@ class Model:
             name = self.__class__.__name__
             mode = inspect.stack()[1][3]  # get the function name.
             raise NotImplementedError(
-                emojis(f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')) from e
+                emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")) from e
 
     @property
     def task_map(self):
@@ -175,8 +175,13 @@ class BaseTrainer:
         if world_size > 1 and 'LOCAL_RANK' not in os.environ:
             # Argument checks
             if self.args.rect:
-                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
+                LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
                 self.args.rect = False
+            if self.args.batch == -1:
+                LOGGER.warning("WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
+                               "default 'batch=16'")
+                self.args.batch = 16
+
             # Command
             cmd, file = generate_ddp_command(world_size, self)
             try:
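The new pre-DDP argument checks read like a small sanitizer: both rect batching and AutoBatch (batch=-1) are incompatible with multi-GPU launches, so they are coerced to safe values before the DDP command is generated. A stand-alone sketch of that logic (hypothetical helper, not library code):

```python
def sanitize_ddp_args(rect: bool, batch: int) -> tuple:
    if rect:
        rect = False  # rect batching does not work with multi-GPU training
    if batch == -1:
        batch = 16    # AutoBatch is single-GPU only; fall back to a fixed default
    return rect, batch

assert sanitize_ddp_args(True, -1) == (False, 16)
assert sanitize_ddp_args(False, 32) == (False, 32)
```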
@@ -186,6 +191,7 @@ class BaseTrainer:
                 raise e
             finally:
                 ddp_cleanup(self, str(file))
+
         else:
             self._do_train(world_size)
 
@@ -248,9 +254,6 @@ class BaseTrainer:
         if self.batch_size == -1:
             if RANK == -1:  # single-GPU only, estimate best batch size
                 self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
-            else:
-                SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
-                            'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
 
         # Dataloaders
         batch_size = self.batch_size // max(world_size, 1)
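The deleted else-branch was dead code: it constructed a SyntaxError but never raised it, so multi-GPU runs with batch=-1 sailed past it unchanged. The effective guard now lives in the argument checks added above, which reset the batch size instead. A two-line illustration of why the old branch had no effect:

```python
# Constructing an exception without 'raise' is a no-op; this is all the removed branch did.
SyntaxError('batch=-1 with Multi-GPU DDP')   # object created and immediately discarded, nothing happens
# raise SyntaxError('batch=-1 with Multi-GPU DDP')  # actually raising would look like this
```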
@@ -18,7 +18,6 @@ from PIL import Image
 from ultralytics.utils import ARM64, LINUX, LOGGER, ROOT, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml
 from ultralytics.utils.downloads import attempt_download_asset, is_url
-from ultralytics.utils.ops import xywh2xyxy
 
 
 def check_class_names(names):
@@ -363,9 +362,13 @@ class AutoBackend(nn.Module):
             # im = im.resize((192, 320), Image.BILINEAR)
             y = self.model.predict({'image': im_pil})  # coordinates are xywh normalized
             if 'confidence' in y:
-                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
-                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
-                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
+                raise TypeError('Ultralytics only supports inference of non-pipelined CoreML models exported with '
+                                f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export.")
+                # TODO: CoreML NMS inference handling
+                # from ultralytics.utils.ops import xywh2xyxy
+                # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
+                # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
+                # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
             elif len(y) == 1:  # classification model
                 y = list(y.values())
             elif len(y) == 2:  # segmentation model
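The commented-out block documents the decode a non-pipelined CoreML head would need: coordinates come back as center-format xywh normalized to [0, 1]. A small NumPy sketch of that conversion using my own helper (not the library's ops.xywh2xyxy, which does the same corner conversion on already-scaled boxes):

```python
import numpy as np

def xywh_norm_to_xyxy(coords: np.ndarray, w: int, h: int) -> np.ndarray:
    """Scale normalized center-format boxes to pixels and convert to corner format."""
    xy = coords[:, :2] * [w, h]   # box centers in pixels
    wh = coords[:, 2:4] * [w, h]  # box sizes in pixels
    return np.concatenate((xy - wh / 2, xy + wh / 2), axis=1)

print(xywh_norm_to_xyxy(np.array([[0.5, 0.5, 0.2, 0.4]]), w=640, h=480))  # [[256. 144. 384. 336.]]
```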
@@ -355,15 +355,15 @@ class DeformableTransformerDecoder(nn.Module):
         for i, layer in enumerate(self.layers):
             output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
 
-            # refine bboxes, (bs, num_queries+num_denoising, 4)
-            refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
+            bbox = bbox_head[i](output)
+            refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))
 
             if self.training:
                 dec_cls.append(score_head[i](output))
                 if i == 0:
                     dec_bboxes.append(refined_bbox)
                 else:
-                    dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
+                    dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))
             elif i == self.eval_idx:
                 dec_cls.append(score_head[i](output))
                 dec_bboxes.append(refined_bbox)
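The decoder change simply factors the duplicated `bbox_head[i](output)` call into a local `bbox` so the same logits feed both refinement branches. A minimal PyTorch sketch of that step, with a simplified stand-in for the library's inverse_sigmoid helper and made-up tensor shapes:

```python
import torch

def inverse_sigmoid(x, eps=1e-5):
    # simplified stand-in: logit with clamping for numerical safety
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

output, refer_bbox = torch.rand(2, 10, 256), torch.rand(2, 10, 4)  # (bs, queries, dim), (bs, queries, 4)
bbox_head = torch.nn.Linear(256, 4)

bbox = bbox_head(output)                                          # computed once per decoder layer
refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))  # reused by both training branches
```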