mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Fix catastrophic accuracy degradation of TFLite static quantized integer models (#1695)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
9627323d7c
commit
3c787eb080
@ -593,14 +593,43 @@ class Exporter:
|
||||
f_onnx, _ = self.export_onnx()
|
||||
|
||||
# Export to TF
|
||||
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
||||
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'
|
||||
LOGGER.info(f"\n{prefix} running '{cmd}'")
|
||||
tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
|
||||
if self.args.int8:
|
||||
if self.args.data:
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.data.dataset import YOLODataset
|
||||
from ultralytics.data.utils import check_det_dataset
|
||||
|
||||
# Generate calibration data for integer quantization
|
||||
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
|
||||
dataset = YOLODataset(check_det_dataset(self.args.data)['val'], imgsz=self.imgsz[0], augment=False)
|
||||
images = []
|
||||
n_images = 100 # maximum number of images
|
||||
for n, batch in enumerate(dataset):
|
||||
if n >= n_images:
|
||||
break
|
||||
im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC,
|
||||
images.append(im)
|
||||
f.mkdir()
|
||||
images = torch.cat(images, 0).float()
|
||||
# mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
|
||||
# std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
|
||||
np.save(str(tmp_file), images.numpy()) # BHWC
|
||||
int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
|
||||
else:
|
||||
int8 = '-oiqt -qt per-tensor'
|
||||
else:
|
||||
int8 = ''
|
||||
|
||||
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'.strip()
|
||||
LOGGER.info(f"{prefix} running '{cmd}'")
|
||||
subprocess.run(cmd, shell=True)
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
||||
# Remove/rename TFLite models
|
||||
if self.args.int8:
|
||||
tmp_file.unlink(missing_ok=True)
|
||||
for file in f.rglob('*_dynamic_range_quant.tflite'):
|
||||
file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
|
||||
for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
|
||||
|
@ -343,6 +343,8 @@ class YOLO:
|
||||
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if 'batch' not in kwargs:
|
||||
overrides['batch'] = 1 # default to 1 if not modified
|
||||
if 'data' not in kwargs:
|
||||
overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||
|
@ -400,21 +400,21 @@ class AutoBackend(nn.Module):
|
||||
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
||||
self.names = {i: f'class{i}' for i in range(nc)}
|
||||
else: # Lite or Edge TPU
|
||||
input = self.input_details[0]
|
||||
int8 = input['dtype'] == np.int8 # is TFLite quantized int8 model
|
||||
if int8:
|
||||
scale, zero_point = input['quantization']
|
||||
im = (im / scale + zero_point).astype(np.int8) # de-scale
|
||||
self.interpreter.set_tensor(input['index'], im)
|
||||
details = self.input_details[0]
|
||||
integer = details['dtype'] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
|
||||
if integer:
|
||||
scale, zero_point = details['quantization']
|
||||
im = (im / scale + zero_point).astype(details['dtype']) # de-scale
|
||||
self.interpreter.set_tensor(details['index'], im)
|
||||
self.interpreter.invoke()
|
||||
y = []
|
||||
for output in self.output_details:
|
||||
x = self.interpreter.get_tensor(output['index'])
|
||||
if int8:
|
||||
if integer:
|
||||
scale, zero_point = output['quantization']
|
||||
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
||||
if x.ndim > 2: # if task is not classification
|
||||
# Unnormalize xywh with input image size
|
||||
# Denormalize xywh with input image size
|
||||
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
|
||||
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
|
||||
x[:, 0] *= w
|
||||
|
Loading…
x
Reference in New Issue
Block a user