You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

264 lines
8.1 KiB
C++

#include "pch.h"
#include <algorithm>
#include <memory>
#include "OCRConfig.h"
#include "EasyOCR_Detector.h"
/// @brief Converts an image to 32-bit float and normalises it channel-wise:
///        (pixel - mean) / std, using mean/std constants scaled to the 0..255 range.
/// @param in  Source image (assumed 3-channel; channel order must match the constants — TODO confirm RGB vs BGR with callers).
/// @return A new float image; `in` is not modified.
cv::Mat uns::EasyOCR_Detector::NormalizeMeanVariance(const cv::Mat& in)
{
	// Per-channel statistics, expressed for pixel values in [0, 255].
	const cv::Scalar channelMean(0.485f * 255, 0.456f * 255, 0.406f * 255);
	const cv::Scalar channelStd(0.229f * 255, 0.224f * 255, 0.225f * 255);

	cv::Mat normalized;
	// convertTo changes the depth only; the channel count follows `in`.
	in.convertTo(normalized, CV_32FC3);
	normalized -= channelMean;
	normalized /= channelStd;
	return normalized;
}
/// @brief Scales every vertex of every polygon in place.
/// @param polys     Polygons to rescale (modified in place).
/// @param ratioW    Horizontal scale factor.
/// @param ratioH    Vertical scale factor.
/// @param ratioNet  Extra factor applied to both axes (e.g. heatmap -> network input).
void uns::EasyOCR_Detector::AdjustResultCoordinates(EOCRD_Rects& polys, float ratioW, float ratioH, float ratioNet)
{
	// The combined per-axis factors do not depend on the point, so compute them once.
	const float scaleX = ratioW * ratioNet;
	const float scaleY = ratioH * ratioNet;
	for (auto&& polygon : polys)
	{
		for (auto&& vertex : polygon)
		{
			vertex.x *= scaleX;
			vertex.y *= scaleY;
		}
	}
}
/// @brief Resizes `src` keeping its aspect ratio, then zero-pads it so both sides are multiples of 32.
/// @param src          Input image; must not be empty (throws cv::Exception otherwise).
/// @param dst          Output: the resized image placed at the top-left of a zero canvas of type src.type().
/// @param squareSize   Upper bound for the longest side after scaling.
/// @param magRatio     Requested magnification of the longest side (capped by squareSize).
/// @param ratio        Output: scale factor applied to src (resized side = int(side * ratio)).
/// @param heatmapSize  Output: half of the padded size (width/2, height/2).
void uns::EasyOCR_Detector::ResizeAspectRatio(const cv::Mat& src, cv::Mat& dst, float squareSize, float magRatio, float& ratio, cv::Size& heatmapSize)
{
	// An empty image would compute 0/0 below and convert NaN to int (undefined behaviour).
	CV_Assert(!src.empty());

	const int h = src.rows;
	const int w = src.cols;
	const int longSide = std::max(h, w);
	const float target = std::min(magRatio * longSide, squareSize);
	ratio = target / longSide;

	// Very elongated images can round a side down to 0 pixels, which cv::resize rejects;
	// keep every side at least one pixel.
	const int targetH = std::max(1, int(h * ratio));
	const int targetW = std::max(1, int(w * ratio));

	// Resize into a local so that passing the same Mat as src and dst is harmless.
	cv::Mat resized;
	cv::resize(src, resized, cv::Size(targetW, targetH));

	// Round each side up to the next multiple of 32 and pad with zeros (bottom/right).
	const int h32 = (targetH + 31) / 32 * 32;
	const int w32 = (targetW + 31) / 32 * 32;
	cv::Mat canvas = cv::Mat::zeros(h32, w32, src.type());
	resized.copyTo(canvas(cv::Rect(0, 0, targetW, targetH)));
	dst = canvas;
	heatmapSize = cv::Size(w32 / 2, h32 / 2);
}
bool uns::EasyOCR_Detector::GetDetBoxesCore(const cv::Mat& textmap, const cv::Mat& linkmap, float textThresh, float linkThresh, float lowText, EOCRD_Rects& boxes, cv::Mat& labels, std::vector<int>& mapper, bool estimateNumChars)
{
cv::Mat tmap = textmap.clone(), lmap = linkmap.clone();
int H = tmap.rows, W = tmap.cols;
// 1. ¶þÖµ»¯ & ºÏ²¢
cv::Mat textScore, linkScore;
cv::threshold(tmap, textScore, lowText, 1, cv::THRESH_BINARY);
cv::threshold(lmap, linkScore, linkThresh, 1, cv::THRESH_BINARY);
cv::Mat combined;
cv::add(textScore, linkScore, combined);
combined = cv::min(combined, 1);
// 2. Á¬Í¨Óò
int nLabels = 0;
cv::Mat stats, centroids;
cv::Mat combined8u;
combined.convertTo(combined8u, CV_8U);
try
{
nLabels = cv::connectedComponentsWithStats(combined8u, labels, stats, centroids, 4);
}
catch (cv::Exception e)
{
return false;
}
// 3. ±éÀúÿ¸ö label
for (int k = 1; k < nLabels; ++k)
{
int area = stats.at<int>(k, cv::CC_STAT_AREA);
if (area < 10)
continue;
// Îı¾ãÐÖµ¹ýÂË
cv::Mat mask = (labels == k);
double maxVal;
cv::minMaxLoc(tmap, nullptr, &maxVal, nullptr, nullptr, mask);
if (maxVal < textThresh)
continue;
// ¹¹½¨ segmap
cv::Mat segmap = cv::Mat::zeros(H, W, CV_8UC1);
segmap.setTo(255, labels == k);
mapper.push_back(k);
// ɾ³ý link ÇøÓò
segmap.setTo(0, (linkScore == 1) & (textScore == 0));
// ÅòÕÍ
int x = stats.at<int>(k, cv::CC_STAT_LEFT);
int y = stats.at<int>(k, cv::CC_STAT_TOP);
int wbox = stats.at<int>(k, cv::CC_STAT_WIDTH);
int hbox = stats.at<int>(k, cv::CC_STAT_HEIGHT);
int niter = int(std::sqrt(area * std::min(wbox, hbox) / float(wbox * hbox)) * 2);
int sx = std::max(0, x - niter), sy = std::max(0, y - niter);
int ex = std::min(W, x + wbox + niter + 1), ey = std::min(H, y + hbox + niter + 1);
cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(1 + niter, 1 + niter));
cv::dilate(segmap(cv::Rect(sx, sy, ex - sx, ey - sy)), segmap(cv::Rect(sx, sy, ex - sx, ey - sy)), kernel);
// ÂÖÀªÄâºÏ×îСÍâ½Ó¾ØÐÎ
std::vector<cv::Point> pts;
cv::findNonZero(segmap, pts);
cv::RotatedRect rect = cv::minAreaRect(pts);
cv::Point2f boxPts[4];
rect.points(boxPts);
std::vector<cv::Point2f> box(boxPts, boxPts + 4);
// diamond->rect
float wlen = (float)cv::norm(box[0] - box[1]);
float hlen = (float)cv::norm(box[1] - box[2]);
float ratio = std::max(wlen, hlen) / (std::min(wlen, hlen) + 1e-5f);
if (std::abs(1 - ratio) <= 0.1f)
{
int minx = W, maxx = 0, miny = H, maxy = 0;
for (auto& p : pts)
{
minx = std::min(minx, p.x); maxx = std::max(maxx, p.x);
miny = std::min(miny, p.y); maxy = std::max(maxy, p.y);
}
box =
{
{ float(minx),float(miny) },
{ float(maxx),float(miny) },
{ float(maxx),float(maxy) },
{ float(minx),float(maxy) }
};
}
// ˳ʱÕëÆðµã
int start = 0;
float minSum = box[0].x + box[0].y;
for (int i = 1; i < 4; i++)
{
float s = box[i].x + box[i].y;
if (s < minSum)
{
minSum = s;
start = i;
}
}
std::rotate(box.begin(), box.begin() + start, box.end());
boxes.push_back(box);
}
return (!boxes.empty());
}
/// @brief Creates an uninitialised detector; call Init() before running detection.
uns::EasyOCR_Detector::EasyOCR_Detector()
{
	// Nothing is loaded yet: Init() decides whether a persistent CPU session is created.
	this->ort_inited = false;
	this->ort_cpu_session = nullptr;

	// Detector model location from the global OCR configuration.
	this->model_path = G_OCRConfig.GetDetectModelPath();

	// ONNX Runtime C API table matching the headers this file was compiled against.
	this->ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
}
bool uns::EasyOCR_Detector::Init()
{
if (ort_inited)
return true;
if (!RecheckModelInfo())
return false;
try
{
ort_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "EasyOCR_Detector");
bool fallback_to_cpu = false;
if (!OCRToolBox::AutoSelectEP(ort, ort_session_options, fallback_to_cpu))
return false;
OCRToolBox::InitOrtSessionOptions(ort_session_options);
if ((G_OCRConfig.GetGPUUsage() == easyocr::GPUUsage::CPUOnly) || fallback_to_cpu) //ʹÓÃCPUÔò³õʼ»¯cpu session
{
ort_cpu_session = new Ort::Session(ort_env, model_path.c_str(), ort_session_options);
//ͨ¹ýCPU session»ñÈ¡ÊäÈëÊä³öÃû
OCRToolBox::GetInputOutputNames(ort_cpu_session, input_names, input_ns, output_names, output_ns);
}
else
{
//ͨ¹ýÁÙʱsession»ñÈ¡ÊäÈëÊä³öÃû£¨CUDA£ºÏ̲߳»°²È«£©
Ort::Session ort_session(ort_env, model_path.c_str(), ort_session_options);
OCRToolBox::GetInputOutputNames(&ort_session, input_names, input_ns, output_names, output_ns);
}
ort_inited = true;
return true;
}
catch (...)
{
return false;
}
}
bool uns::EasyOCR_Detector::UnInit()
{
try
{
if (ort_cpu_session != nullptr)
delete ort_cpu_session;
ort_cpu_session = nullptr;
return true;
}
catch (...)
{
return false;
}
}
/// @brief Ensures a model path is set (loading it from configuration if empty) and checks the file.
/// @return The result of OCRToolBox::CheckFile on the model path.
bool uns::EasyOCR_Detector::RecheckModelInfo()
{
	const bool pathMissing = model_path.empty();
	if (pathMissing)
	{
		model_path = G_OCRConfig.GetDetectModelPath();
	}
	return OCRToolBox::CheckFile(model_path);
}
/// @brief Detects text regions in `image`.
/// @param image Input image (3 channels expected by the model).
/// @return One 4-point box per detected region in original-image coordinates; empty when the
///         model check fails, the detector is not initialised, the image is empty, the model
///         output has an unexpected shape, nothing is found, or any exception is thrown.
uns::EOCRD_Rects uns::EasyOCR_Detector::operator()(const cv::Mat& image)
{
	// 0. Preconditions.
	if (!RecheckModelInfo())
		return {};
	if (!ort_inited || image.empty())
		return {};
	try
	{
		// 1. Resize (longest side capped at 1280, padded to multiples of 32) and normalise.
		cv::Mat resized;
		float ratio = 1.0f;
		cv::Size heatmapSize;
		ResizeAspectRatio(image, resized, 1280.0f, 1.5f, ratio, heatmapSize);
		const cv::Mat normalized = NormalizeMeanVariance(resized);

		// 2. Pack into an NCHW float blob. The tensor below borrows blob's memory,
		//    so blob must stay alive until Run() returns.
		cv::Mat blob = cv::dnn::blobFromImage(normalized);
		const std::array<int64_t, 4> shape = { 1, blob.size[1], blob.size[2], blob.size[3] };
		Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
		Ort::Value tensor = Ort::Value::CreateTensor<float>(memInfo, blob.ptr<float>(), blob.total(), shape.data(), shape.size());

		// 3. ONNX inference: persistent CPU session if Init() created one, otherwise a
		//    per-call session (see Init(): the CUDA session is not shared).
		std::vector<Ort::Value> outputs;
		if (ort_cpu_session != nullptr)
		{
			outputs = ort_cpu_session->Run(Ort::RunOptions{ nullptr }, input_names.data(), &tensor, 1, output_names.data(), 1);
		}
		else
		{
			Ort::Session session(ort_env, model_path.c_str(), ort_session_options);
			outputs = session.Run(Ort::RunOptions{ nullptr }, input_names.data(), &tensor, 1, output_names.data(), 1);
		}
		if (outputs.empty())
			return {};

		// 4. Split the output, laid out as [1, H, W, 2] (channel 0 = text score,
		//    channel 1 = link score). Reject anything else instead of reading out of bounds.
		const std::vector<int64_t> outputShape = outputs[0].GetTensorTypeAndShapeInfo().GetShape();
		if (outputShape.size() != 4 || outputShape[3] != 2)
			return {};
		const int H = int(outputShape[1]);
		const int W = int(outputShape[2]);
		if (H <= 0 || W <= 0)
			return {};
		float* outData = outputs[0].GetTensorMutableData<float>();
		const cv::Mat interleaved(H, W, CV_32FC2, outData); // wraps the buffer, no copy
		std::vector<cv::Mat> planes;
		cv::split(interleaved, planes); // allocates independent single-channel maps
		const cv::Mat& score_text = planes[0];
		const cv::Mat& score_link = planes[1];

		// 5. Boxes in heatmap coordinates.
		EOCRD_Rects boxes;
		cv::Mat labels;
		std::vector<int> mapper;
		if (!GetDetBoxesCore(score_text, score_link, textThreshold, linkThreshold, lowText, boxes, labels, mapper, false))
			return {};

		// 6. Map back to the original image: heatmap -> resized input (per-axis net ratio,
		//    normally 2) -> original image (1 / ratio).
		const float invRatio = 1.0f / ratio;
		const float ratioNetW = float(resized.cols) / float(W);
		const float ratioNetH = float(resized.rows) / float(H);
		AdjustResultCoordinates(boxes, invRatio * ratioNetW, invRatio * ratioNetH, 1.0f);
		return boxes;
	}
	catch (...)
	{
		return {};
	}
}