mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Fixed RTDETR GFLOPs bug (#7309)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
61fa12460d
commit
a334ff7f6e
@ -274,16 +274,26 @@ def model_info_for_loggers(trainer):
|
||||
|
||||
def get_flops(model, imgsz=640):
|
||||
"""Return a YOLO model's FLOPs."""
|
||||
if not thop:
|
||||
return 0.0 # if not installed return 0.0 GFLOPs
|
||||
|
||||
try:
|
||||
model = de_parallel(model)
|
||||
p = next(model.parameters())
|
||||
if not isinstance(imgsz, list):
|
||||
imgsz = [imgsz, imgsz] # expand if int/float
|
||||
try:
|
||||
# Use stride size for input tensor
|
||||
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
||||
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
||||
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0 # stride GFLOPs
|
||||
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
||||
return flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
|
||||
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
||||
return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs
|
||||
except Exception:
|
||||
return 0
|
||||
# Use actual image size for input tensor (i.e. required for RTDETR models)
|
||||
im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
|
||||
return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 # imgsz GFLOPs
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
|
||||
def get_flops_with_torch_profiler(model, imgsz=640):
|
||||
|
Loading…
x
Reference in New Issue
Block a user