You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
264 lines
8.1 KiB
C++
264 lines
8.1 KiB
C++
#include "pch.h"
|
|
#include "OCRConfig.h"
|
|
#include "EasyOCR_Detector.h"
|
|
|
|
// Convert an image to 32-bit float and standardise it per channel.
//
// The per-channel offsets (0.485, 0.456, 0.406) and divisors
// (0.229, 0.224, 0.225) are scaled by 255 before being applied.
//
// @param in  Source image; it is left untouched.
// @return    A new floating-point matrix: (pixel - offset) / divisor.
cv::Mat uns::EasyOCR_Detector::NormalizeMeanVariance(const cv::Mat& in)
{
    const cv::Scalar channelOffset(0.485f * 255, 0.456f * 255, 0.406f * 255);
    const cv::Scalar channelDivisor(0.229f * 255, 0.224f * 255, 0.225f * 255);

    cv::Mat normalized;
    in.convertTo(normalized, CV_32FC3);

    normalized -= channelOffset;
    normalized /= channelDivisor;
    return normalized;
}
|
|
|
|
// Rescale every vertex of every polygon in place.
//
// @param polys     Polygons to modify.
// @param ratioW    Horizontal factor.
// @param ratioH    Vertical factor.
// @param ratioNet  Extra factor applied on both axes.
//
// x is multiplied by ratioW * ratioNet and y by ratioH * ratioNet.
void uns::EasyOCR_Detector::AdjustResultCoordinates(EOCRD_Rects& polys, float ratioW, float ratioH, float ratioNet)
{
    const float scaleX = ratioW * ratioNet;
    const float scaleY = ratioH * ratioNet;

    for (auto& polygon : polys)
    {
        for (auto& vertex : polygon)
        {
            vertex.x *= scaleX;
            vertex.y *= scaleY;
        }
    }
}
|
|
|
|
// Resize an image while keeping its aspect ratio, then pad it with zeros so
// that both sides are multiples of 32.
//
// @param src          Source image. Must not be empty (a cv::Exception is
//                     raised through CV_Assert otherwise).
// @param dst          Receives the padded canvas; the resized image sits in
//                     its top-left corner.
// @param squareSize   Upper bound for the longer side after magnification.
// @param magRatio     Magnification applied to the longer side before the
//                     bound is enforced.
// @param ratio        Receives the scale factor actually used
//                     (target long side / original long side).
// @param heatmapSize  Receives half of the padded canvas size (width, height).
void uns::EasyOCR_Detector::ResizeAspectRatio(const cv::Mat& src, cv::Mat& dst, float squareSize, float magRatio, float& ratio, cv::Size& heatmapSize)
{
    // Without this the ratio below would be computed as 0 / 0.
    CV_Assert(!src.empty());

    const int srcH = src.rows;
    const int srcW = src.cols;
    const int longSide = std::max(srcH, srcW);

    float target = magRatio * longSide;
    if (target > squareSize)
        target = squareSize;
    ratio = target / longSide;

    // Truncation can yield 0 for very elongated images (e.g. 4000 x 2), which
    // cv::resize rejects; keep at least one pixel on each axis.
    const int targetH = std::max(1, int(srcH * ratio));
    const int targetW = std::max(1, int(srcW * ratio));

    cv::Mat scaled;
    cv::resize(src, scaled, cv::Size(targetW, targetH));

    // Round each side up to the next multiple of 32 and zero-pad.
    const int h32 = (targetH + 31) / 32 * 32;
    const int w32 = (targetW + 31) / 32 * 32;
    cv::Mat canvas = cv::Mat::zeros(h32, w32, src.type());
    scaled.copyTo(canvas(cv::Rect(0, 0, targetW, targetH)));

    dst = canvas;
    heatmapSize = cv::Size(w32 / 2, h32 / 2);
}
|
|
|
|
// Extract text boxes from a text score map and a link score map.
//
// @param textmap           Per-pixel text score.
// @param linkmap           Per-pixel link score.
// @param textThresh        A component is kept only if its highest text score
//                          reaches this value.
// @param linkThresh        Binarisation threshold for linkmap.
// @param lowText           Binarisation threshold for textmap.
// @param boxes             Receives one 4-point box per accepted component
//                          (appended; existing content is kept).
// @param labels            Receives the connected-component label image.
// @param mapper            Receives the label id of each appended box, in the
//                          same order as boxes.
// @param estimateNumChars  Not used by this implementation.
// @return true when boxes is non-empty on exit; false when it is empty or
//         when connected-component labelling raised a cv::Exception.
bool uns::EasyOCR_Detector::GetDetBoxesCore(const cv::Mat& textmap, const cv::Mat& linkmap, float textThresh, float linkThresh, float lowText, EOCRD_Rects& boxes, cv::Mat& labels, std::vector<int>& mapper, bool estimateNumChars)
{
    // The input maps are only read, so no private copies are needed.
    const int H = textmap.rows, W = textmap.cols;

    // 1. Binarise both maps and merge them into a single 0/1 map.
    cv::Mat textScore, linkScore;
    cv::threshold(textmap, textScore, lowText, 1, cv::THRESH_BINARY);
    cv::threshold(linkmap, linkScore, linkThresh, 1, cv::THRESH_BINARY);
    cv::Mat combined;
    cv::add(textScore, linkScore, combined);
    combined = cv::min(combined, 1);

    // 2. Connected components (4-connectivity).
    int nLabels = 0;
    cv::Mat stats, centroids;
    cv::Mat combined8u;
    combined.convertTo(combined8u, CV_8U);
    try
    {
        nLabels = cv::connectedComponentsWithStats(combined8u, labels, stats, centroids, 4);
    }
    catch (const cv::Exception&)
    {
        return false;
    }

    // Pixels that are link-only; identical for every component, so build once.
    const cv::Mat linkOnly = (linkScore == 1) & (textScore == 0);

    // 3. Visit every component (label 0 is the background).
    for (int k = 1; k < nLabels; ++k)
    {
        const int area = stats.at<int>(k, cv::CC_STAT_AREA);
        if (area < 10)
            continue;

        // Text-threshold filter: the peak text score inside the component.
        const cv::Mat componentMask = (labels == k);
        double maxVal = 0.0;
        cv::minMaxLoc(textmap, nullptr, &maxVal, nullptr, nullptr, componentMask);
        if (maxVal < textThresh)
            continue;

        // Build the segmentation map of this component and strip link-only pixels.
        cv::Mat segmap = cv::Mat::zeros(H, W, CV_8UC1);
        segmap.setTo(255, componentMask);
        segmap.setTo(0, linkOnly);

        // Dilate inside the component's bounding box, enlarged by niter.
        const int x = stats.at<int>(k, cv::CC_STAT_LEFT);
        const int y = stats.at<int>(k, cv::CC_STAT_TOP);
        const int wbox = stats.at<int>(k, cv::CC_STAT_WIDTH);
        const int hbox = stats.at<int>(k, cv::CC_STAT_HEIGHT);
        const int niter = int(std::sqrt(area * std::min(wbox, hbox) / float(wbox * hbox)) * 2);
        const int sx = std::max(0, x - niter), sy = std::max(0, y - niter);
        const int ex = std::min(W, x + wbox + niter + 1), ey = std::min(H, y + hbox + niter + 1);
        const cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(1 + niter, 1 + niter));
        cv::Mat roi = segmap(cv::Rect(sx, sy, ex - sx, ey - sy));
        cv::dilate(roi, roi, kernel);

        // Fit the minimum-area rectangle around the remaining pixels.
        std::vector<cv::Point> pts;
        cv::findNonZero(segmap, pts);
        if (pts.empty())
            continue; // nothing left after removing link-only pixels

        cv::RotatedRect rect = cv::minAreaRect(pts);
        cv::Point2f boxPts[4];
        rect.points(boxPts);
        std::vector<cv::Point2f> box(boxPts, boxPts + 4);

        // Nearly square (diamond-like) result: use the axis-aligned bounds instead.
        const float wlen = (float)cv::norm(box[0] - box[1]);
        const float hlen = (float)cv::norm(box[1] - box[2]);
        const float ratio = std::max(wlen, hlen) / (std::min(wlen, hlen) + 1e-5f);
        if (std::abs(1 - ratio) <= 0.1f)
        {
            int minx = W, maxx = 0, miny = H, maxy = 0;
            for (const auto& p : pts)
            {
                minx = std::min(minx, p.x); maxx = std::max(maxx, p.x);
                miny = std::min(miny, p.y); maxy = std::max(maxy, p.y);
            }
            box =
            {
                { float(minx), float(miny) },
                { float(maxx), float(miny) },
                { float(maxx), float(maxy) },
                { float(minx), float(maxy) }
            };
        }

        // Rotate so the corner with the smallest x + y comes first
        // (the first such corner wins on ties).
        const auto firstCorner = std::min_element(box.begin(), box.end(),
            [](const cv::Point2f& a, const cv::Point2f& b) { return (a.x + a.y) < (b.x + b.y); });
        std::rotate(box.begin(), firstCorner, box.end());

        // Record label and box together so both outputs stay index-aligned.
        mapper.push_back(k);
        boxes.push_back(box);
    }
    return (!boxes.empty());
}
|
|
|
|
// Build a detector that has not been initialised yet: the ONNX Runtime API
// table and the configured detection-model path are fetched, while the
// session pointer stays null and the initialised flag stays false.
uns::EasyOCR_Detector::EasyOCR_Detector()
{
    // Runtime API handle and model location.
    ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    model_path = G_OCRConfig.GetDetectModelPath();

    // Session state is filled in by Init().
    ort_cpu_session = nullptr;
    ort_inited = false;
}
|
|
|
|
bool uns::EasyOCR_Detector::Init()
|
|
{
|
|
if (ort_inited)
|
|
return true;
|
|
if (!RecheckModelInfo())
|
|
return false;
|
|
try
|
|
{
|
|
ort_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "EasyOCR_Detector");
|
|
bool fallback_to_cpu = false;
|
|
if (!OCRToolBox::AutoSelectEP(ort, ort_session_options, fallback_to_cpu))
|
|
return false;
|
|
OCRToolBox::InitOrtSessionOptions(ort_session_options);
|
|
if ((G_OCRConfig.GetGPUUsage() == easyocr::GPUUsage::CPUOnly) || fallback_to_cpu) //ʹÓÃCPUÔò³õʼ»¯cpu session
|
|
{
|
|
ort_cpu_session = new Ort::Session(ort_env, model_path.c_str(), ort_session_options);
|
|
//ͨ¹ýCPU session»ñÈ¡ÊäÈëÊä³öÃû
|
|
OCRToolBox::GetInputOutputNames(ort_cpu_session, input_names, input_ns, output_names, output_ns);
|
|
}
|
|
else
|
|
{
|
|
//ͨ¹ýÁÙʱsession»ñÈ¡ÊäÈëÊä³öÃû£¨CUDA£ºÏ̲߳»°²È«£©
|
|
Ort::Session ort_session(ort_env, model_path.c_str(), ort_session_options);
|
|
OCRToolBox::GetInputOutputNames(&ort_session, input_names, input_ns, output_names, output_ns);
|
|
}
|
|
ort_inited = true;
|
|
return true;
|
|
}
|
|
catch (...)
|
|
{
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool uns::EasyOCR_Detector::UnInit()
|
|
{
|
|
try
|
|
{
|
|
if (ort_cpu_session != nullptr)
|
|
delete ort_cpu_session;
|
|
ort_cpu_session = nullptr;
|
|
return true;
|
|
}
|
|
catch (...)
|
|
{
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool uns::EasyOCR_Detector::RecheckModelInfo()
|
|
{
|
|
if (model_path.empty())
|
|
model_path = G_OCRConfig.GetDetectModelPath();
|
|
return OCRToolBox::CheckFile(model_path);
|
|
}
|
|
|
|
// Run text detection on an image.
//
// @param image  Input image.
// @return Detected boxes mapped back to the coordinates of `image`. An empty
//         result is returned when the model file check fails, the image is
//         empty, the input/output names have not been loaded, the model
//         output does not have the expected layout, no box is found, or any
//         step throws.
uns::EOCRD_Rects uns::EasyOCR_Detector::operator()(const cv::Mat& image)
{
    // 0. Check the model.
    if (!RecheckModelInfo())
        return {};
    if (image.empty())
        return {};

    try
    {
        // Run() below reads exactly one input and one output name.
        if (input_names.empty() || output_names.empty())
            return {};

        // 1. Resize + normalise.
        cv::Mat resized;
        float ratio = 1.0f;
        cv::Size heatmapSize;
        ResizeAspectRatio(image, resized, 1280.0f, 1.5f, ratio, heatmapSize);
        cv::Mat input = NormalizeMeanVariance(resized);

        // 2. Build the NCHW tensor.
        cv::dnn::blobFromImage(input, input);
        std::array<int64_t, 4> shape = { 1, 3, input.size[2], input.size[3] };
        Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        Ort::Value tensor = Ort::Value::CreateTensor<float>(memInfo, input.ptr<float>(), input.total(), shape.data(), shape.size());

        // 3. ONNX inference: use the persistent CPU session when there is
        //    one, otherwise a session created just for this call.
        auto outputs = ((ort_cpu_session != nullptr) ? ort_cpu_session->Run(Ort::RunOptions{ nullptr }, input_names.data(), &tensor, 1, output_names.data(), 1) : Ort::Session(ort_env, model_path.c_str(), ort_session_options).Run(Ort::RunOptions{ nullptr }, input_names.data(), &tensor, 1, output_names.data(), 1));
        if (outputs.empty())
            return {};

        // The de-interleaving below requires a [N, H, W, 2] layout.
        const std::vector<int64_t> outputShape = outputs[0].GetTensorTypeAndShapeInfo().GetShape();
        if (outputShape.size() != 4 || outputShape[3] != 2)
            return {};
        const int H = int(outputShape[1]), W = int(outputShape[2]);
        if (H <= 0 || W <= 0)
            return {};

        // 4. Split the interleaved buffer into score_text (channel 0) and
        //    score_link (channel 1). cv::split copies, so the planes own
        //    their data independently of the output tensor.
        float* outData = outputs[0].GetTensorMutableData<float>();
        const cv::Mat interleaved(H, W, CV_32FC2, outData);
        cv::Mat planes[2];
        cv::split(interleaved, planes);
        const cv::Mat& score_text = planes[0];
        const cv::Mat& score_link = planes[1];

        // 5. Boxes in heat-map coordinates.
        EOCRD_Rects boxes;
        cv::Mat labels;
        std::vector<int> mapper;
        if (!GetDetBoxesCore(score_text, score_link, textThreshold, linkThreshold, lowText, boxes, labels, mapper, false))
            return {};

        // 6. Final mapping factors. heatmapSize is half of the padded canvas,
        //    so both net ratios are 2 in practice; each axis still gets its
        //    own factor.
        const float invRatio = 1.0f / ratio;
        const float ratioNetW = float(resized.cols) / float(heatmapSize.width);
        const float ratioNetH = float(resized.rows) / float(heatmapSize.height);

        // 7. Map back to the original image.
        AdjustResultCoordinates(boxes, invRatio * ratioNetW, invRatio * ratioNetH, 1.0f);
        return boxes;
    }
    catch (...)
    {
        return {};
    }
}
|