You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
1.7 KiB
C++
55 lines
1.7 KiB
C++
//
// Created by tunm on 2023/2/11.
//
|
|
|
|
|
|
#include "basic_types.h"
|
|
#include "../test_settings.h"
|
|
#include "opencv2/opencv.hpp"
|
|
#include "nn_implementation_module/classification/all.h"
|
|
#include "basic_types.h"
|
|
#include "utils.h"
|
|
|
|
using namespace hyper;
|
|
|
|
// Unit test for ClassificationEngine.
//
// Initializes the engine from the classification model file with a 96x96
// input size, runs inference on a fixed list of images and compares the
// predicted plate color and the maximum confidence against hard-coded
// expected values.
TEST_CASE("test_Classification", "[nn_cls]") {
    PRINT_SPLIT_LINE
    LOGD("[UnitTest]->Classification Model");

    std::string model_path = GET_DATA("models/r2_mobile/litemodel_cls_96xh.mnn");

    // Fixture tables: entry i of each vector describes the same sample.
    const std::vector<std::string> predict_images_list = {
        GET_DATA("images/align/1.jpg"),
        GET_DATA("images/align/3.jpg"),
        GET_DATA("images/align/5.jpg"),
    };
    const std::vector<PlateColor> predict_results_cls = {
        PlateColor::BLUE, PlateColor::YELLOW, PlateColor::GREEN,
    };
    const std::vector<float> predict_results_confidence = {
        0.9999293f, 0.8975975f, 0.9997952f
    };

    // REQUIRE (not CHECK): the loop below indexes all three tables with the
    // same index, so a size mismatch must abort the test instead of reading
    // out of range.
    REQUIRE(predict_results_confidence.size() == predict_results_cls.size());
    REQUIRE(predict_results_confidence.size() == predict_images_list.size());

    ClassificationEngine clsEngine;
    auto ret = clsEngine.Initialize(model_path, cv::Size_<int>(96, 96));
    // Running inference on an engine that failed to initialize is pointless,
    // so stop here if initialization did not succeed.
    REQUIRE(ret == InferenceHelper::kRetOk);

    SECTION("test_ClassificationModelPredict") {
        // std::size_t index avoids the signed/unsigned comparison with size().
        for (std::size_t i = 0; i < predict_images_list.size(); ++i) {
            cv::Mat img = cv::imread(predict_images_list[i]);
            // An unreadable image must not be fed to the engine.
            REQUIRE(!img.empty());
            CHECK(img.cols == 96);
            CHECK(img.rows == 96);

            ret = clsEngine.Inference(img);
            CHECK(ret == InferenceHelper::kRetOk);
            CHECK(PlateColor(clsEngine.getMOutputColor()) == predict_results_cls[i]);
            // Confidence is compared with a relative tolerance of 0.1%.
            CHECK(clsEngine.getMOutputMaxConfidence() == Approx(predict_results_confidence[i]).epsilon(0.001));
        }
    }
}
|