You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

57 lines
2.2 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
主程序作者:王昱博
车牌识别系统:
使用OCR技术对车牌号码进行识别
使用图像分类AI对车牌种类进行区分
"""
import os

import cv2

from ocr import OCR
from cut_image import ImageCutter
from classification_ai import ClassificationAI
# Classification model files, indexed by main()'s `classify_model_index`.
# Paths are relative to the current working directory. They are built with
# os.path.join so the separator matches the host OS: on Windows this yields
# '.\\classify_model\\...', and elsewhere './classify_model/...'.
classify_models = [
    os.path.join('.', 'classify_model', model_file)
    for model_file in ('0.0625.pkl', '0.0625-2.pkl', '0.125.pkl')
]
def train(train_set_path: str, export_path: str) -> None:
    """Train the plate-type classifier.

    Thin wrapper that forwards both paths unchanged to
    ``ClassificationAI.TrainAI``. Whatever that call returns is discarded,
    and any exception it raises propagates to the caller.

    Args:
        train_set_path: location of the training set.
        export_path: destination for the trained model (the CLI prompts for
            it as the model save path).
    """
    trainer = ClassificationAI.TrainAI
    trainer(train_set_path, export_path)
    return None
def main(classify_model_index: int, image_path: str) -> None:
    """Recognize the license plate in an image and print the results.

    Flow:
      1. Preprocess the image with ``ImageCutter.ImagePreProcess``.
      2. Try ``OCR.RecognizeLicensePlate2``, which yields text, confidence
         and a cropped plate image.
      3. If no crop came back, cut the plate region with
         ``ImageCutter.CutPlateRect``.
      4. Run ``OCR.RecognizeLicensePlate`` on the crop for a second text
         reading and an OCR-based plate type.
      5. Classify the crop with the selected model via
         ``ClassificationAI.PredictImage``.

    The plate number printed is the RecognizeLicensePlate2 text when it is
    not None. Otherwise it falls back to the RecognizeLicensePlate text, and
    the confidence is then reported as None.

    Args:
        classify_model_index: 0-based index into the module-level
            ``classify_models`` list selecting the classifier model file.
        image_path: path of the image to recognize.

    Raises:
        IndexError: if ``classify_model_index`` is not in
            ``0 .. len(classify_models) - 1``.
    """
    # Validate before any image work. This fails fast, and it stops a
    # negative index from silently wrapping around to the last model.
    # (Reading the module-level list needs no `global` statement.)
    if not 0 <= classify_model_index < len(classify_models):
        raise IndexError(
            f'classify_model_index must be in 0..{len(classify_models) - 1}, '
            f'got {classify_model_index}'
        )
    model_path = classify_models[classify_model_index]

    origin_image, gray_image = ImageCutter.ImagePreProcess(image_path)
    lpr_text, lpr_conf, cut_image = OCR.RecognizeLicensePlate2(origin_image)
    if cut_image is None:
        # RecognizeLicensePlate2 gave no crop; locate the plate region ourselves.
        cut_image = ImageCutter.CutPlateRect(origin_image, gray_image)
    ocr_text, ocr_type = OCR.RecognizeLicensePlate(cut_image, lpr_text)
    if lpr_text is None:
        # Fall back to the second OCR reading. Its confidence is unknown here.
        lpr_text = ocr_text
        lpr_conf = None
    ai_type, ai_conf = ClassificationAI.PredictImage(cut_image, model_path)
    print(f'识别完成,以下为识别结果:\n车牌号:{lpr_text} [置信度:{lpr_conf}]\n车牌类型:\n\t{ocr_type}(OCR推测)\n\t{ai_type}(AI分类识别)\n\tAI识别置信度{ai_conf}')
def _parse_model_choice(text: str, count: int = 3):
    """Convert the user's 1-based model choice into a 0-based list index.

    Args:
        text: raw user input, e.g. ``'1'``.
        count: number of selectable models (valid choices are ``1..count``).

    Returns:
        ``int(text) - 1`` when ``text`` parses as an integer in ``1..count``,
        otherwise ``None``.
    """
    # int() is the authority on what parses. str.isdigit() accepts
    # characters such as '²' that int() rejects.
    try:
        choice = int(text)
    except ValueError:
        return None
    if 1 <= choice <= count:
        return choice - 1
    return None


if __name__ == '__main__':
    # Interactive CLI: 't' trains a classifier, 'r' recognizes an image.
    mode = input('请选择运行模式(训练(t)/识别(r)): ').strip().lower()
    if mode == 't':
        data_path = input('输入训练集路径: ')
        export_path = input('输入模型保存路径: ')
        try:
            train(data_path, export_path)
        except Exception as e:
            # Top-level boundary: report the failure instead of a traceback.
            print(f'训练过程中发生错误: {e}')
        else:
            print('模型已成功训练')
        finally:
            print('训练结束')
    elif mode == 'r':
        model_index = input('选择使用的识别模型(1/2/3): ')
        image_path = input('输入图片路径: ')
        # The prompt is 1-based; main() expects a 0-based index into classify_models.
        index = _parse_model_choice(model_index, len(classify_models))
        if index is None:
            print('输入有误')
        else:
            main(index, image_path)
    else:
        print('输入有误')