diff --git a/classification_ai.py b/classification_ai.py
index 2e0f80f..8ba0f50 100644
--- a/classification_ai.py
+++ b/classification_ai.py
@@ -28,6 +28,25 @@ class ClassificationAI:
     def ConvertImage(cv_img: cv2.Mat) -> Image.Image:
         return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)).resize((460, 460))
 
+    @staticmethod
+    def ConvertClassifyResult(cla: str) -> str:
+        if cla == 'ForeignV':
+            return '外籍车辆'
+        elif cla == 'In-fieldV':
+            return '场内车辆'
+        elif cla == 'large-scaleNewenergyV':
+            return '大型新能源车辆'
+        elif cla == 'MediumLarge-sizedV':
+            return '中/大型车辆'
+        elif cla == 'MilitaryPoliceEmergencyV':
+            return '军/警/应急车辆'
+        elif cla == 'SmallCar':
+            return '小型轿车'
+        elif cla == 'SmallNewEnergyV':
+            return '小型新能源轿车'
+        else:
+            return '未知'
+
     @classmethod
     def TrainAI(cls, data_set_path: str, export_path: str) -> None:
         blocks = (ImageBlock, CategoryBlock)
@@ -39,7 +58,7 @@ class ClassificationAI:
             get_y=parent_label,
             item_tfms=Resize(460),
             batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)]
-        ).dataloaders(data_set_path, num_workers=4, bs=batch_size)
+        ).dataloaders(data_set_path, num_workers=0, bs=batch_size)
         model = vision_learner(dls, resnet34, metrics=error_rate)
         model.fine_tune(5, freeze_epochs=3)  # 5 - 训练的轮次, 3 - 冻结的轮次
         model.export(Path(export_path) / 'model.pkl')
@@ -58,4 +77,5 @@ class ClassificationAI:
             confidence = float(outputs)
         else:
             confidence = float(outputs[pred_idx])
+        pred_class = cls.ConvertClassifyResult(pred_class)
         return pred_class, confidence
diff --git a/main.py b/main.py
index c257a41..3afd190 100644
--- a/main.py
+++ b/main.py
@@ -22,15 +22,19 @@ def main(classify_model_index: int, image_path: str) -> None:
     global classify_models
     origin_image, gray_image = ImageCutter.ImagePreProcess(image_path)
     lpr_text, lpr_conf, cut_image = OCR.RecognizeLicensePlate2(origin_image)
-    ocr_text, ocr_type = OCR.RecognizeLicensePlate(cut_image)
+    if cut_image is None:
+        cut_image = ImageCutter.CutPlateRect(origin_image, gray_image)
+    ocr_text, ocr_type = OCR.RecognizeLicensePlate(cut_image, lpr_text)
+    if lpr_text is None:
+        lpr_text = ocr_text
+        lpr_conf = None
     ai_type, ai_conf = ClassificationAI.PredictImage(cut_image, classify_models[classify_model_index])
     print(f'识别完成,以下为识别结果:\n车牌号:{lpr_text} [置信度:{lpr_conf}]\n车牌类型:\n\t{ocr_type}(OCR推测)\n\t{ai_type}(AI分类识别)\n\tAI识别置信度:{ai_conf}')
-    cv2.waitKey(0)
 
 
 if __name__ == '__main__':
-    result = input('请选择运行模式(训练(y)/识别(n)): ')
-    if result == 'y' or result == 'Y':
+    result = input('请选择运行模式(训练(t)/识别(r)): ')
+    if result == 't' or result == 'T':
         data_path = input('输入训练集路径: ')
         export_path = input('输入模型保存路径: ')
         try:
@@ -41,7 +45,7 @@ if __name__ == '__main__':
             print('模型已成功训练')
         finally:
            print('训练结束')
-    elif result == 'n' or result == 'N':
+    elif result == 'r' or result == 'R':
         model_index = input('选择使用的识别模型(1/2/3): ')
         image_path = input('输入图片路径: ')
         if (not model_index.isdigit()) or (int(model_index) < 1) or (int(model_index) > 3):
diff --git a/ocr.py b/ocr.py
index fdb9c15..f3ecfe8 100644
--- a/ocr.py
+++ b/ocr.py
@@ -19,20 +19,24 @@ class OCR:
         return text
 
     @classmethod
-    def RecognizeLicensePlate(cls, image: cv2.Mat) -> tuple:
+    def RecognizeLicensePlate(cls, image: cv2.Mat, lpr_text: str) -> tuple:
         reader = easyocr.Reader(['ch_sim', 'en'], model_storage_directory='./easyocr_model')
         result = reader.readtext(image)
         license_plate = ""
         for res in result:
             license_plate += res[-2]  # 如果车牌号码是两行的,按行识别出来再拼接起来
+        if lpr_text is not None:
+            license_plate = lpr_text
         if '\u8b66' in license_plate:
-            car_type = 'police'
+            car_type = '警用车辆'
         elif '\u573a\u5185' in license_plate:
-            car_type = 'internal'
+            car_type = '场内车辆'
         elif '\u6302' in license_plate:
-            car_type = 'bigCar'
+            car_type = '挂车/半挂车'
+        elif len(license_plate) > 7:
+            car_type = '新能源车辆'
         else:
-            car_type = 'smallCar'
+            car_type = '无法推测'
         return license_plate, car_type
 
     @classmethod
@@ -43,4 +47,4 @@ class OCR:
             x0, y0, x1, y1 = box
             cut_image = image[y0:y1, x0:x1]
             return code, conf, cut_image
-        return None, None, image
+        return None, None, None