#include "pch.h"
#include "OCRCharset.h"
#include "EasyOCR_Recognizer.h"
uns::EasyOCR_Recognizer::NormalizePAD::Size3i uns::EasyOCR_Recognizer::NormalizePAD::Size3i::operator=(const Size3i& obj)
{
    d0 = obj.d0;
    d1 = obj.d1;
    d2 = obj.d2;
    return (*this);
}
uns::EasyOCR_Recognizer::NormalizePAD::NormalizePAD(Size3i max_size, const std::string& PAD_type)
{
    this->max_size = max_size;
    this->PAD_type = PAD_type;
    max_width_half = max_size.d2 / 2; // half of the maximum width, kept for optional use
}
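// Mirrors EasyOCR's NormalizePAD transform: scales the image to [-1, 1] as 32-bit float,
// then right-pads it to max_size.d2 columns by repeating the last pixel column.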
cv::Mat uns::EasyOCR_Recognizer::NormalizePAD::operator()(const cv::Mat& input_img) const
{
    // Convert the input image to 32-bit float and scale to [0, 1]
    cv::Mat img;
    input_img.convertTo(img, CV_32F, 1.0 / 255); // line 10: img = toTensor
    img = (img - 0.5f) / 0.5f;                   // line 11: img.sub_(0.5).div_(0.5)
    int h = img.rows;       // image height
    int w = img.cols;       // image width
    int c = img.channels(); // channel count (1 for grayscale)
    // Create an all-zero 32F Mat of the target size (max_size.d1 x max_size.d2)
    cv::Mat pad_img = cv::Mat::zeros(max_size.d1, max_size.d2, CV_32FC(c)); // line 13
    // Copy the source image into the top-left region of pad_img (right-side padding)
    img.copyTo(pad_img(cv::Rect(0, 0, w, h))); // line 14
    // If the target width exceeds the image width, extend with the last column of pixels
    if (max_size.d2 != w)
    { // line 15
        cv::Mat last_col = img.col(w - 1);
        cv::Mat border;
        cv::repeat(last_col, 1, max_size.d2 - w, border); // tile the last column across the padded area
        border.copyTo(pad_img(cv::Rect(w, 0, max_size.d2 - w, h)));
    }
    return pad_img; // return the normalized, padded float tensor
}
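// Contrast stretch for low-contrast grey crops: if the measured contrast is below `target`,
// remap pixels with (p - low + 25) * 200 / max(10, high - low), clamped to [0, 255].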
cv::Mat uns::EasyOCR_Recognizer::AlignCollate::AdjustContrastGrey(const cv::Mat& img_in, double target) const
{
    double contrast;
    int high, low;
    ContrastGrey(img_in, contrast, high, low);
    cv::Mat img = img_in.clone();
    if (contrast < target)
    {
        cv::Mat img_i;
        img.convertTo(img_i, CV_32S);
        double ratio = 200.0 / std::max(10, high - low);
        img_i = (img_i - low + 25) * ratio;
        // Clamp pixel values to [0, 255] and convert back to 8-bit
        img_i.forEach<int>([] (int& pixel, const int*)
        {
            pixel = std::clamp(pixel, 0, 255);
        });
        img_i.convertTo(img, CV_8U);
    }
    return img;
}
void uns::EasyOCR_Recognizer::AlignCollate::ContrastGrey(const cv::Mat& img, double& contrast, int& high, int& low) const
{
    // Copy the pixel values into a contiguous vector<int> so they can be sorted
    std::vector<int> pixels;
    pixels.reserve(img.rows * img.cols); // pre-allocate for efficiency
    for (int i = 0; i < img.rows; ++i)
    {
        const uchar* row_ptr = img.ptr<uchar>(i);
        for (int j = 0; j < img.cols; ++j)
            pixels.push_back(static_cast<int>(row_ptr[j]));
    }
    // Sort the pixel values so percentiles can be read off by index
    std::sort(pixels.begin(), pixels.end());
    // Indices of the 90th and 10th percentiles (approximating Python's np.percentile, without interpolation)
    int idx90 = static_cast<int>(0.9 * (pixels.size() - 1));
    int idx10 = static_cast<int>(0.1 * (pixels.size() - 1));
    high = pixels[idx90];
    low = pixels[idx10];
    // contrast = (high - low) / max(10, high + low)
    contrast = double(high - low) / double(std::max(10, high + low));
}
uns::EasyOCR_Recognizer::AlignCollate::AlignCollate(int imgH, int imgW, bool keep_ratio_with_pad, double adjust_contrast)
{
    this->imgH = imgH;
    this->imgW = imgW;
    this->adjust_contrast = adjust_contrast;
    this->keep_ratio_with_pad = keep_ratio_with_pad;
}
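// Counterpart of EasyOCR's AlignCollate: optionally boosts contrast, resizes each crop to
// height imgH (capping the width at imgW), normalizes and pads it, then packs the batch into
// an NCHW blob via cv::dnn::blobFromImages.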
cv::Mat uns::EasyOCR_Recognizer::AlignCollate::operator()(const std::vector<cv::Mat>& batch) const
{
    std::vector<cv::Mat> resized_images;
    // Create a NormalizePAD instance for normalization and right-side padding
    NormalizePAD transform({ 1, imgH, imgW });
    for (const cv::Mat& image : batch)
    {
        cv::Mat working;
        if (adjust_contrast > 0)
        {
            cv::Mat grey;
            if (image.channels() > 1)
                cv::cvtColor(image, grey, cv::COLOR_BGR2GRAY);
            else
                grey = image;
            working = AdjustContrastGrey(grey, adjust_contrast);
        }
        else
            working = image;
        int w = working.cols;
        int h = working.rows;
        double ratio = double(w) / h;
        int resized_w = static_cast<int>(std::ceil(imgH * ratio));
        if (resized_w > imgW)
            resized_w = imgW;
        cv::Mat resized;
        cv::resize(working, resized, cv::Size(resized_w, imgH), 0, 0, cv::INTER_CUBIC);
        cv::Mat tensor = transform(resized);
        resized_images.push_back(tensor);
    }
    cv::Mat blob;
    cv::dnn::blobFromImages(resized_images, blob);
    return blob;
}
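// Confidence aggregation used by EasyOCR: conf = (product of the non-zero x_i) ^ (2 / sqrt(N)),
// i.e. a geometric-mean-like score of the per-character probabilities.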
float uns::EasyOCR_Recognizer::CustomMean(const VecFloat& x)
{
    size_t N = x.size();
    if (N == 0)
        return 0.0f;
    // 1. Product of the elements (zeros are skipped so a single zero does not collapse the product)
    double prod = 1.0;
    for (float v : x)
        if (v != 0)
            prod *= static_cast<double>(v);
    // 2. Exponent 2.0 / sqrt(N)
    double exponent = 2.0 / std::sqrt(static_cast<double>(N));
    // 3. Return prod raised to that exponent
    return static_cast<float>(std::pow(prod, exponent));
}
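// Converts a BGR/BGRA/grayscale image to a single grayscale crop, resizes it to a model height
// of 64 pixels (handling vertical text via CalculateRatio), and returns the normalized NCHW
// input blob produced by AlignCollate. An empty Mat means "skip this crop".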
cv::Mat uns::EasyOCR_Recognizer::Preprocess(const cv::Mat& img) const
{
    if (img.empty())
        return {}; // throwing here is not appropriate; an empty image simply stops the downstream processing
    cv::Mat gray;
    int ch = img.channels();
    // BGR color image (3 channels)
    if (ch == 3)
        cv::cvtColor(img, gray, cv::COLOR_BGR2GRAY);
    // BGRA color image (4 channels)
    else if (ch == 4)
        cv::cvtColor(img, gray, cv::COLOR_BGRA2GRAY); // drop the alpha channel: BGRA -> GRAY
    else // the image is (h x w) or (h x w x 1); treat it as grayscale either way
        gray = img;
    int width = gray.cols;
    int height = gray.rows;
    int model_height = 64, model_width = 0;
    float ratio = static_cast<float>(width) / static_cast<float>(height);
    cv::Mat resized;
    if (ratio < 1.0f)
    {
        // Vertical text: CalculateRatio inverts the aspect ratio; resize with width = model_height
        float adj_ratio = CalculateRatio(width, height);
        model_width = static_cast<int>(model_height * adj_ratio);
        cv::resize(gray, resized, cv::Size(model_height, model_width), 0, 0, cv::INTER_LINEAR);
        ratio = adj_ratio;
    }
    else
    {
        // Horizontal text: resize with height = model_height
        model_width = static_cast<int>(model_height * ratio);
        cv::resize(gray, resized, cv::Size(model_width, model_height), 0, 0, cv::INTER_LINEAR);
    }
    AlignCollate alignCollate(model_height, model_width, true, 0.5);
    return alignCollate({ resized });
}
float uns::EasyOCR_Recognizer::CalculateRatio(int width, int height) const
{
    float ratio = static_cast<float>(width) / static_cast<float>(height);
    if (ratio < 1.0f)
        ratio = 1.0f / ratio;
    return ratio;
}
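// Numerically stable softmax over C logits: p_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)).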
uns::VecFloat uns::EasyOCR_Recognizer::SoftMAX(const float* logits, int C) const
{
    // Find the maximum logit for numerical stability
    float m = logits[0];
    for (int i = 1; i < C; ++i)
        m = std::max(m, logits[i]);
    // Compute exp(logit - m)
    std::vector<float> exps(C);
    float sum = 0.f;
    for (int i = 0; i < C; ++i)
    {
        exps[i] = std::exp(logits[i] - m);
        sum += exps[i];
    }
    // Normalize
    for (int i = 0; i < C; ++i)
        exps[i] /= (sum > 1e-6f ? sum : 1e-6f);
    return exps;
}
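// Greedy (argmax) decoding of the recognizer output tensor [N, T, C]: for every time step,
// apply softmax, zero out the classes listed in ignore_idx, renormalize, and record the most
// probable class index together with its probability.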
void uns::EasyOCR_Recognizer::PostprocessONNXOutput(const Ort::Value& outputs, int N, int T, int C, VecInt& out_indices, VecFloat& out_probs, const VecInt ignore_idx)
{
    // Raw pointer access to the underlying tensor data
    const float* data = outputs.GetTensorData<float>();
    out_indices.clear();
    out_probs.clear();
    // Temporary storage for the per-step probabilities
    std::vector<float> probs;
    probs.reserve(C);
    // Iterate over every sample and every time step
    for (int n = 0; n < N; ++n)
    {
        for (int t = 0; t < T; ++t)
        {
            // Logits for this step start at ((n * T) + t) * C
            const float* logits = data + ((size_t)n * T + t) * C;
            // 1) Softmax
            probs = SoftMAX(logits, C);
            // 2) Zero out the ignored class indices
            if (!ignore_idx.empty())
                for (const auto& idx : ignore_idx)
                    probs[idx] = 0.f;
            // 3) Renormalize
            float sum = 0.f;
            for (int c = 0; c < C; ++c)
                sum += probs[c];
            if (sum > 1e-6f)
            {
                for (int c = 0; c < C; ++c)
                    probs[c] /= sum;
            }
            // 4) Greedy pick of the most probable class
            int best = 0;
            float best_prob = probs[0];
            for (int c = 1; c < C; ++c)
            {
                if (probs[c] > probs[best])
                {
                    best = c;
                    best_prob = probs[c];
                }
            }
            out_indices.push_back(best);
            out_probs.push_back(best_prob);
        }
    }
}
uns::EasyOCR_Recognizer::EasyOCR_Recognizer()
{
    ort_inited = false;
    ort_cpu_session = nullptr;
    model_path = G_OCRConfig.GetRecognizeModelPath();
    ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
}
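// One-time ONNX Runtime setup: picks an execution provider via OCRToolBox::AutoSelectEP and,
// on the CPU-only path, keeps a persistent Ort::Session; otherwise only the input/output names
// are cached here and a session is created per inference call (see operator()).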
bool uns::EasyOCR_Recognizer::Init()
{
    if (ort_inited)
        return true;
    if (!RecheckModelInfo())
        return false;
    try
    {
        ort_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "EasyOCR_Recognizer");
        bool fallback_to_cpu = false;
        if (!OCRToolBox::AutoSelectEP(ort, ort_session_options, fallback_to_cpu))
            return false;
        OCRToolBox::InitOrtSessionOptions(ort_session_options);
        if ((G_OCRConfig.GetGPUUsage() == easyocr::GPUUsage::CPUOnly) || fallback_to_cpu) // CPU path: create a persistent CPU session
        {
            ort_cpu_session = new Ort::Session(ort_env, model_path.c_str(), ort_session_options);
            // Query input/output names through the CPU session
            OCRToolBox::GetInputOutputNames(ort_cpu_session, input_names, input_ns, output_names, output_ns);
        }
        else
        {
            // Query input/output names through a temporary session (the CUDA session is not thread-safe)
            Ort::Session ort_session(ort_env, model_path.c_str(), ort_session_options);
            OCRToolBox::GetInputOutputNames(&ort_session, input_names, input_ns, output_names, output_ns);
        }
        ort_inited = true;
        return true;
    }
    catch (...)
    {
        return false;
    }
}
bool uns::EasyOCR_Recognizer::UnInit()
{
    try
    {
        if (ort_cpu_session != nullptr)
            delete ort_cpu_session;
        ort_cpu_session = nullptr;
        return true;
    }
    catch (...)
    {
        return false;
    }
}
bool uns::EasyOCR_Recognizer::RecheckModelInfo()
{
    if (model_path.empty())
        model_path = G_OCRConfig.GetRecognizeModelPath();
    return OCRToolBox::CheckFile(model_path);
}
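// Recognizes the text in a single pre-cropped image.
// The returned confidence encodes the failure mode: -1.0f = model file missing,
// 0.0f = empty/unusable input, -2.0f = exception during inference.
//
// Minimal usage sketch (variable names other than the class itself are hypothetical):
//   uns::EasyOCR_Recognizer recognizer;
//   if (recognizer.Init())
//   {
//       auto [text, conf] = recognizer(crop); // crop: cv::Mat containing one line of text
//   }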
uns::EOCR_Result uns::EasyOCR_Recognizer::operator()(const cv::Mat& image)
{
    try
    {
        if (!RecheckModelInfo())
            return { L"", -1.0f };
        cv::Mat input = Preprocess(image);
        if (input.empty())
            return { L"", 0.0f };
        std::array<int64_t, 4> inputShape = { 1, 1, input.size[2], input.size[3] };
        Ort::MemoryInfo memInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memInfo, input.ptr<float>(), input.total(), inputShape.data(), inputShape.size());
        auto outputs = (ort_cpu_session != nullptr)
            ? ort_cpu_session->Run(Ort::RunOptions{ nullptr }, input_names.data(), &inputTensor, 1, output_names.data(), 1)
            : Ort::Session(ort_env, model_path.c_str(), ort_session_options).Run(Ort::RunOptions{ nullptr }, input_names.data(), &inputTensor, 1, output_names.data(), 1);
        // Output shape: [1, T, C]
        auto& outVal = outputs.front();
        auto info = outVal.GetTensorTypeAndShapeInfo();
        auto shape = info.GetShape(); // {1, T, C}
        int N = (int)shape[0], T = (int)shape[1], C = (int)shape[2];
        // Greedy pick & softmax
        std::vector<int> indices(T);
        std::vector<float> maxProbs(T);
        PostprocessONNXOutput(outputs[0], N, T, C, indices, maxProbs);
        // Decode the class indices into text
        std::wstring text = OCRCharset::GetString(indices);
        // Confidence score
        float conf = CustomMean(maxProbs);
        return { text, conf };
    }
    catch (...)
    {
        return { L"", -2.0f };
    }
}
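// Batch variant: crops each region described in `rects` out of the full image via its bounding
// rectangle and runs the single-image operator() on every crop, collecting the results into an
// EOCR_ResultSet keyed by the region index.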
uns::EOCR_ResultSet uns::EasyOCR_Recognizer::operator()(const cv::Mat& full_image, const EOCRD_Rects& rects)
{
    if (!RecheckModelInfo())
        return {};
    try
    {
        EOCR_ResultSet result_set;
        for (size_t i = 0; i < rects.size(); ++i)
        {
            // Convert the polygon to its axis-aligned bounding rectangle and crop it
            cv::Rect rect = cv::boundingRect(rects[i]);
            rect &= cv::Rect(0, 0, full_image.cols, full_image.rows); // clip to the image bounds
            cv::Mat crop = full_image(rect);
            if (crop.empty())
                continue;
            auto [text, conf] = (*this)(crop);
            result_set.insert({ i, { text, conf, rect } });
        }
        return result_set;
    }
    catch (...)
    {
        return {};
    }
}