mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
Fix per-layer FLOPs profiling (#4048)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e6d18cc944
commit
379f31904a
@ -87,9 +87,8 @@ class BaseModel(nn.Module):
|
|||||||
|
|
||||||
def _predict_augment(self, x):
|
def _predict_augment(self, x):
|
||||||
"""Perform augmentations on input image x and return augmented inference."""
|
"""Perform augmentations on input image x and return augmented inference."""
|
||||||
LOGGER.warning(
|
LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
|
||||||
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
f'Reverting to single-scale inference instead.')
|
||||||
)
|
|
||||||
return self._predict_once(x)
|
return self._predict_once(x)
|
||||||
|
|
||||||
def _profile_one_layer(self, m, x, dt):
|
def _profile_one_layer(self, m, x, dt):
|
||||||
@ -105,15 +104,15 @@ class BaseModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
|
||||||
o = thop.profile(m, inputs=[x.clone() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
||||||
t = time_sync()
|
t = time_sync()
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
m(x.clone() if c else x)
|
m(x.copy() if c else x)
|
||||||
dt.append((time_sync() - t) * 100)
|
dt.append((time_sync() - t) * 100)
|
||||||
if m == self.model[0]:
|
if m == self.model[0]:
|
||||||
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
||||||
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}')
|
||||||
if c:
|
if c:
|
||||||
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
||||||
|
|
||||||
@ -338,7 +337,7 @@ class ClassificationModel(BaseModel):
|
|||||||
"""YOLOv8 classification model."""
|
"""YOLOv8 classification model."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
cfg=None,
|
cfg='yolov8n-cls.yaml',
|
||||||
model=None,
|
model=None,
|
||||||
ch=3,
|
ch=3,
|
||||||
nc=None,
|
nc=None,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user