mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
ORT_CPP add CUDA FP16 inference (#4320)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
02d4f5200d
commit
1c753cbce6
@ -16,6 +16,10 @@ find_package(OpenCV REQUIRED)
|
|||||||
include_directories(${OpenCV_INCLUDE_DIRS})
|
include_directories(${OpenCV_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
|
||||||
|
# -------------- CUDA Toolkit is required (its headers and libraries are used for FP16 inference) ------------------#
|
||||||
|
find_package(CUDA REQUIRED)
|
||||||
|
include_directories(${CUDA_INCLUDE_DIRS})
|
||||||
|
|
||||||
|
|
||||||
# ONNXRUNTIME
|
# ONNXRUNTIME
|
||||||
|
|
||||||
@ -51,9 +55,9 @@ set(PROJECT_SOURCES
|
|||||||
add_executable(${PROJECT_NAME} ${PROJECT_SOURCES})
|
add_executable(${PROJECT_NAME} ${PROJECT_SOURCES})
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/onnxruntime.lib)
|
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/onnxruntime.lib ${CUDA_LIBRARIES})
|
||||||
elseif(LINUX)
|
elseif(LINUX)
|
||||||
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.so)
|
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.so ${CUDA_LIBRARIES})
|
||||||
elseif(APPLE)
|
elseif(APPLE)
|
||||||
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.dylib)
|
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.dylib)
|
||||||
endif()
|
endif()
|
||||||
|
@ -6,8 +6,7 @@ This example demonstrates how to perform inference using YOLOv8 in C++ with ONNX
|
|||||||
|
|
||||||
- Friendly for deployment in the industrial sector.
|
- Friendly for deployment in the industrial sector.
|
||||||
- Faster than OpenCV's DNN inference on both CPU and GPU.
|
- Faster than OpenCV's DNN inference on both CPU and GPU.
|
||||||
- Supports CUDA acceleration.
|
- Supports FP32 and FP16 CUDA acceleration.
|
||||||
- Easy to add FP16 inference (using template functions).
|
|
||||||
|
|
||||||
## Exporting YOLOv8 Models
|
## Exporting YOLOv8 Models
|
||||||
|
|
||||||
@ -47,13 +46,12 @@ Note: The dependency on C++17 is due to the usage of the C++17 filesystem featur
|
|||||||
DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, 0.1, 0.5, false};
|
DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, 0.1, 0.5, false};
|
||||||
// GPU inference
|
// GPU inference
|
||||||
DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, 0.1, 0.5, true};
|
DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {imgsz_w, imgsz_h}, 0.1, 0.5, true};
|
||||||
|
|
||||||
// Load your image
|
// Load your image
|
||||||
cv::Mat img = cv::imread(img_path);
|
cv::Mat img = cv::imread(img_path);
|
||||||
|
// Init Inference Session
|
||||||
|
char* ret = yoloDetector->CreateSession(params);
|
||||||
|
|
||||||
char* ret = p1->CreateSession(params);
|
ret = yoloDetector->RunSession(img, res);
|
||||||
|
|
||||||
ret = p->RunSession(img, res);
|
|
||||||
```
|
```
|
||||||
|
|
||||||
This repository should also work for YOLOv5 once a permute operator is applied to the output of the YOLOv5 model; that operator has not been implemented yet.
|
This repository should also work for YOLOv5 once a permute operator is applied to the output of the YOLOv5 model; that operator has not been implemented yet.
|
||||||
|
@ -15,6 +15,13 @@ DCSP_CORE::~DCSP_CORE()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace Ort
|
||||||
|
{
|
||||||
|
template<>
|
||||||
|
struct TypeToTensorType<half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
char* BlobFromImage(cv::Mat& iImg, T& iBlob)
|
char* BlobFromImage(cv::Mat& iImg, T& iBlob)
|
||||||
{
|
{
|
||||||
@ -56,7 +63,7 @@ char* DCSP_CORE::CreateSession(DCSP_INIT_PARAM &iParams)
|
|||||||
bool result = std::regex_search(iParams.ModelPath, pattern);
|
bool result = std::regex_search(iParams.ModelPath, pattern);
|
||||||
if (result)
|
if (result)
|
||||||
{
|
{
|
||||||
Ret = "[DCSP_ONNX]:model path error.change your model path without chinese characters.";
|
Ret = "[DCSP_ONNX]:Model path error.Change your model path without chinese characters.";
|
||||||
std::cout << Ret << std::endl;
|
std::cout << Ret << std::endl;
|
||||||
return Ret;
|
return Ret;
|
||||||
}
|
}
|
||||||
@ -109,9 +116,7 @@ char* DCSP_CORE::CreateSession(DCSP_INIT_PARAM &iParams)
|
|||||||
}
|
}
|
||||||
options = Ort::RunOptions{ nullptr };
|
options = Ort::RunOptions{ nullptr };
|
||||||
WarmUpSession();
|
WarmUpSession();
|
||||||
//std::cout << OrtGetApiBase()->GetVersionString() << std::endl;;
|
return RET_OK;
|
||||||
Ret = RET_OK;
|
|
||||||
return Ret;
|
|
||||||
}
|
}
|
||||||
catch (const std::exception& e)
|
catch (const std::exception& e)
|
||||||
{
|
{
|
||||||
@ -122,7 +127,6 @@ char* DCSP_CORE::CreateSession(DCSP_INIT_PARAM &iParams)
|
|||||||
std::strcpy(merged, result.c_str());
|
std::strcpy(merged, result.c_str());
|
||||||
std::cout << merged << std::endl;
|
std::cout << merged << std::endl;
|
||||||
delete[] merged;
|
delete[] merged;
|
||||||
//return merged;
|
|
||||||
return "[DCSP_ONNX]:Create session failed.";
|
return "[DCSP_ONNX]:Create session failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,6 +149,13 @@ char* DCSP_CORE::RunSession(cv::Mat &iImg, std::vector<DCSP_RESULT>& oResult)
|
|||||||
std::vector<int64_t> inputNodeDims = { 1,3,imgSize.at(0),imgSize.at(1) };
|
std::vector<int64_t> inputNodeDims = { 1,3,imgSize.at(0),imgSize.at(1) };
|
||||||
TensorProcess(starttime_1, iImg, blob, inputNodeDims, oResult);
|
TensorProcess(starttime_1, iImg, blob, inputNodeDims, oResult);
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
half* blob = new half[processedImg.total() * 3];
|
||||||
|
BlobFromImage(processedImg, blob);
|
||||||
|
std::vector<int64_t> inputNodeDims = { 1,3,imgSize.at(0),imgSize.at(1) };
|
||||||
|
TensorProcess(starttime_1, iImg, blob, inputNodeDims, oResult);
|
||||||
|
}
|
||||||
|
|
||||||
return Ret;
|
return Ret;
|
||||||
}
|
}
|
||||||
@ -169,7 +180,8 @@ char* DCSP_CORE::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std
|
|||||||
delete blob;
|
delete blob;
|
||||||
switch (modelType)
|
switch (modelType)
|
||||||
{
|
{
|
||||||
case 1:
|
case 1://V8_ORIGIN_FP32
|
||||||
|
case 4://V8_ORIGIN_FP16
|
||||||
{
|
{
|
||||||
int strideNum = outputNodeDims[2];
|
int strideNum = outputNodeDims[2];
|
||||||
int signalResultNum = outputNodeDims[1];
|
int signalResultNum = outputNodeDims[1];
|
||||||
@ -243,15 +255,13 @@ char* DCSP_CORE::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
char* Ret = RET_OK;
|
return RET_OK;
|
||||||
return Ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
char* DCSP_CORE::WarmUpSession()
|
char* DCSP_CORE::WarmUpSession()
|
||||||
{
|
{
|
||||||
clock_t starttime_1 = clock();
|
clock_t starttime_1 = clock();
|
||||||
char* Ret = RET_OK;
|
|
||||||
cv::Mat iImg = cv::Mat(cv::Size(imgSize.at(0), imgSize.at(1)), CV_8UC3);
|
cv::Mat iImg = cv::Mat(cv::Size(imgSize.at(0), imgSize.at(1)), CV_8UC3);
|
||||||
cv::Mat processedImg;
|
cv::Mat processedImg;
|
||||||
PostProcess(iImg, imgSize, processedImg);
|
PostProcess(iImg, imgSize, processedImg);
|
||||||
@ -270,5 +280,20 @@ char* DCSP_CORE::WarmUpSession()
|
|||||||
std::cout << "[DCSP_ONNX(CUDA)]: " << "Cuda warm-up cost " << post_process_time << " ms. " << std::endl;
|
std::cout << "[DCSP_ONNX(CUDA)]: " << "Cuda warm-up cost " << post_process_time << " ms. " << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Ret;
|
else
|
||||||
|
{
|
||||||
|
half* blob = new half[iImg.total() * 3];
|
||||||
|
BlobFromImage(processedImg, blob);
|
||||||
|
std::vector<int64_t> YOLO_input_node_dims = { 1,3,imgSize.at(0),imgSize.at(1) };
|
||||||
|
Ort::Value input_tensor = Ort::Value::CreateTensor<half>(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), YOLO_input_node_dims.data(), YOLO_input_node_dims.size());
|
||||||
|
auto output_tensors = session->Run(options, inputNodeNames.data(), &input_tensor, 1, outputNodeNames.data(), outputNodeNames.size());
|
||||||
|
delete[] blob;
|
||||||
|
clock_t starttime_4 = clock();
|
||||||
|
double post_process_time = (double)(starttime_4 - starttime_1) / CLOCKS_PER_SEC * 1000;
|
||||||
|
if (cudaEnable)
|
||||||
|
{
|
||||||
|
std::cout << "[DCSP_ONNX(CUDA)]: " << "Cuda warm-up cost " << post_process_time << " ms. " << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <opencv2/opencv.hpp>
|
#include <opencv2/opencv.hpp>
|
||||||
#include "onnxruntime_cxx_api.h"
|
#include "onnxruntime_cxx_api.h"
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
|
||||||
enum MODEL_TYPE
|
enum MODEL_TYPE
|
||||||
@ -21,7 +22,10 @@ enum MODEL_TYPE
|
|||||||
YOLO_ORIGIN_V5 = 0,
|
YOLO_ORIGIN_V5 = 0,
|
||||||
YOLO_ORIGIN_V8 = 1,//only support v8 detector currently
|
YOLO_ORIGIN_V8 = 1,//only support v8 detector currently
|
||||||
YOLO_POSE_V8 = 2,
|
YOLO_POSE_V8 = 2,
|
||||||
YOLO_CLS_V8 = 3
|
YOLO_CLS_V8 = 3,
|
||||||
|
YOLO_ORIGIN_V8_HALF = 4,
|
||||||
|
YOLO_POSE_V8_HALF = 5,
|
||||||
|
YOLO_CLS_V8_HALF = 6
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,13 +82,15 @@ int read_coco_yaml(DCSP_CORE*& p)
|
|||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
DCSP_CORE* p1 = new DCSP_CORE;
|
DCSP_CORE* yoloDetector = new DCSP_CORE;
|
||||||
std::string model_path = "yolov8n.onnx";
|
std::string model_path = "yolov8n.onnx";
|
||||||
read_coco_yaml(p1);
|
read_coco_yaml(yoloDetector);
|
||||||
// GPU inference
|
// GPU FP32 inference
|
||||||
DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 0.1, 0.5, true };
|
DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 0.1, 0.5, true };
|
||||||
|
// GPU FP16 inference
|
||||||
|
// DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8_HALF, {640, 640}, 0.1, 0.5, true };
|
||||||
// CPU inference
|
// CPU inference
|
||||||
// DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 0.1, 0.5, false };
|
// DCSP_INIT_PARAM params{ model_path, YOLO_ORIGIN_V8, {640, 640}, 0.1, 0.5, false };
|
||||||
p1->CreateSession(params);
|
yoloDetector->CreateSession(params);
|
||||||
file_iterator(p1);
|
file_iterator(yoloDetector);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user