mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
Fix catastrophic accuracy degradation of TFLite static quantized integer models (#1695)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
9627323d7c
commit
3c787eb080
@ -593,14 +593,43 @@ class Exporter:
|
|||||||
f_onnx, _ = self.export_onnx()
|
f_onnx, _ = self.export_onnx()
|
||||||
|
|
||||||
# Export to TF
|
# Export to TF
|
||||||
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
tmp_file = f / 'tmp_tflite_int8_calibration_images.npy' # int8 calibration images file
|
||||||
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'
|
if self.args.int8:
|
||||||
LOGGER.info(f"\n{prefix} running '{cmd}'")
|
if self.args.data:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ultralytics.data.dataset import YOLODataset
|
||||||
|
from ultralytics.data.utils import check_det_dataset
|
||||||
|
|
||||||
|
# Generate calibration data for integer quantization
|
||||||
|
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
|
||||||
|
dataset = YOLODataset(check_det_dataset(self.args.data)['val'], imgsz=self.imgsz[0], augment=False)
|
||||||
|
images = []
|
||||||
|
n_images = 100 # maximum number of images
|
||||||
|
for n, batch in enumerate(dataset):
|
||||||
|
if n >= n_images:
|
||||||
|
break
|
||||||
|
im = batch['img'].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC,
|
||||||
|
images.append(im)
|
||||||
|
f.mkdir()
|
||||||
|
images = torch.cat(images, 0).float()
|
||||||
|
# mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
|
||||||
|
# std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
|
||||||
|
np.save(str(tmp_file), images.numpy()) # BHWC
|
||||||
|
int8 = f'-oiqt -qt per-tensor -cind images "{tmp_file}" "[[[[0, 0, 0]]]]" "[[[[255, 255, 255]]]]"'
|
||||||
|
else:
|
||||||
|
int8 = '-oiqt -qt per-tensor'
|
||||||
|
else:
|
||||||
|
int8 = ''
|
||||||
|
|
||||||
|
cmd = f'onnx2tf -i "{f_onnx}" -o "{f}" -nuo --non_verbose {int8}'.strip()
|
||||||
|
LOGGER.info(f"{prefix} running '{cmd}'")
|
||||||
subprocess.run(cmd, shell=True)
|
subprocess.run(cmd, shell=True)
|
||||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||||
|
|
||||||
# Remove/rename TFLite models
|
# Remove/rename TFLite models
|
||||||
if self.args.int8:
|
if self.args.int8:
|
||||||
|
tmp_file.unlink(missing_ok=True)
|
||||||
for file in f.rglob('*_dynamic_range_quant.tflite'):
|
for file in f.rglob('*_dynamic_range_quant.tflite'):
|
||||||
file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
|
file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
|
||||||
for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
|
for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
|
||||||
|
@ -343,6 +343,8 @@ class YOLO:
|
|||||||
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||||
if 'batch' not in kwargs:
|
if 'batch' not in kwargs:
|
||||||
overrides['batch'] = 1 # default to 1 if not modified
|
overrides['batch'] = 1 # default to 1 if not modified
|
||||||
|
if 'data' not in kwargs:
|
||||||
|
overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
|
||||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||||
args.task = self.task
|
args.task = self.task
|
||||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||||
|
@ -400,21 +400,21 @@ class AutoBackend(nn.Module):
|
|||||||
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
||||||
self.names = {i: f'class{i}' for i in range(nc)}
|
self.names = {i: f'class{i}' for i in range(nc)}
|
||||||
else: # Lite or Edge TPU
|
else: # Lite or Edge TPU
|
||||||
input = self.input_details[0]
|
details = self.input_details[0]
|
||||||
int8 = input['dtype'] == np.int8 # is TFLite quantized int8 model
|
integer = details['dtype'] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
|
||||||
if int8:
|
if integer:
|
||||||
scale, zero_point = input['quantization']
|
scale, zero_point = details['quantization']
|
||||||
im = (im / scale + zero_point).astype(np.int8) # de-scale
|
im = (im / scale + zero_point).astype(details['dtype']) # de-scale
|
||||||
self.interpreter.set_tensor(input['index'], im)
|
self.interpreter.set_tensor(details['index'], im)
|
||||||
self.interpreter.invoke()
|
self.interpreter.invoke()
|
||||||
y = []
|
y = []
|
||||||
for output in self.output_details:
|
for output in self.output_details:
|
||||||
x = self.interpreter.get_tensor(output['index'])
|
x = self.interpreter.get_tensor(output['index'])
|
||||||
if int8:
|
if integer:
|
||||||
scale, zero_point = output['quantization']
|
scale, zero_point = output['quantization']
|
||||||
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
||||||
if x.ndim > 2: # if task is not classification
|
if x.ndim > 2: # if task is not classification
|
||||||
# Unnormalize xywh with input image size
|
# Denormalize xywh with input image size
|
||||||
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
|
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
|
||||||
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
|
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
|
||||||
x[:, 0] *= w
|
x[:, 0] *= w
|
||||||
|
Loading…
x
Reference in New Issue
Block a user