Ver2.0提交

master
UnknownObject 2 years ago
parent 3cb1186b8a
commit fa7b5af994

@ -28,6 +28,25 @@ class ClassificationAI:
def ConvertImage(cv_img: cv2.Mat) -> Image.Image:
return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)).resize((460, 460))
@staticmethod
def ConvertClassifyResult(cla: str) -> str:
if cla == 'ForeignV':
return '外籍车辆'
elif cla == 'In-fieldV':
return '场内车辆'
elif cla == 'large-scaleNewenergyV':
return '大型新能源车辆'
elif cla == 'MediumLarge-sizedV':
return '中/大型车辆'
elif cla == 'MilitaryPoliceEmergencyV':
return '军/警/应急车辆'
elif cla == 'SmallCar':
return '小型轿车'
elif cla == 'SmallNewEnergyV':
return '小型新能源轿车'
else:
return '未知'
@classmethod
def TrainAI(cls, data_set_path: str, export_path: str) -> None:
blocks = (ImageBlock, CategoryBlock)
@ -39,7 +58,7 @@ class ClassificationAI:
get_y=parent_label,
item_tfms=Resize(460),
batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)]
).dataloaders(data_set_path, num_workers=4, bs=batch_size)
).dataloaders(data_set_path, num_workers=0, bs=batch_size)
model = vision_learner(dls, resnet34, metrics=error_rate)
model.fine_tune(5, freeze_epochs=3) # 5 - 训练的轮次, 3 - 冻结的轮次
model.export(Path(export_path) / 'model.pkl')
@ -58,4 +77,5 @@ class ClassificationAI:
confidence = float(outputs)
else:
confidence = float(outputs[pred_idx])
pred_class = cls.ConvertClassifyResult(pred_class)
return pred_class, confidence

@ -22,15 +22,19 @@ def main(classify_model_index: int, image_path: str) -> None:
global classify_models
origin_image, gray_image = ImageCutter.ImagePreProcess(image_path)
lpr_text, lpr_conf, cut_image = OCR.RecognizeLicensePlate2(origin_image)
ocr_text, ocr_type = OCR.RecognizeLicensePlate(cut_image)
if cut_image is None:
cut_image = ImageCutter.CutPlateRect(origin_image, gray_image)
ocr_text, ocr_type = OCR.RecognizeLicensePlate(cut_image, lpr_text)
if lpr_text is None:
lpr_text = ocr_text
lpr_conf = None
ai_type, ai_conf = ClassificationAI.PredictImage(cut_image, classify_models[classify_model_index])
print(f'识别完成,以下为识别结果:\n车牌号:{lpr_text} [置信度:{lpr_conf}]\n车牌类型:\n\t{ocr_type}(OCR推测)\n\t{ai_type}(AI分类识别)\n\tAI识别置信度{ai_conf}')
cv2.waitKey(0)
if __name__ == '__main__':
result = input('请选择运行模式(训练(y)/识别(n)): ')
if result == 'y' or result == 'Y':
result = input('请选择运行模式(训练(t)/识别(r)): ')
if result == 't' or result == 'T':
data_path = input('输入训练集路径: ')
export_path = input('输入模型保存路径: ')
try:
@ -41,7 +45,7 @@ if __name__ == '__main__':
print('模型已成功训练')
finally:
print('训练结束')
elif result == 'n' or result == 'N':
elif result == 'r' or result == 'R':
model_index = input('选择使用的识别模型(1/2/3): ')
image_path = input('输入图片路径: ')
if (not model_index.isdigit()) or (int(model_index) < 1) or (int(model_index) > 3):

@ -19,20 +19,24 @@ class OCR:
return text
@classmethod
def RecognizeLicensePlate(cls, image: cv2.Mat) -> tuple:
def RecognizeLicensePlate(cls, image: cv2.Mat, lpr_text: str) -> tuple:
reader = easyocr.Reader(['ch_sim', 'en'], model_storage_directory='./easyocr_model')
result = reader.readtext(image)
license_plate = ""
for res in result:
license_plate += res[-2] # 如果车牌号码是两行的,按行识别出来再拼接起来
if lpr_text is not None:
license_plate = lpr_text
if '\u8b66' in license_plate:
car_type = 'police'
car_type = '警用车辆'
elif '\u573a\u5185' in license_plate:
car_type = 'internal'
car_type = '场内车辆'
elif '\u6302' in license_plate:
car_type = 'bigCar'
car_type = '挂车/半挂车'
elif len(license_plate) > 7:
car_type = '新能源车辆'
else:
car_type = 'smallCar'
car_type = '无法推测'
return license_plate, car_type
@classmethod
@ -43,4 +47,4 @@ class OCR:
x0, y0, x1, y1 = box
cut_image = image[y0:y1, x0:x1]
return code, conf, cut_image
return None, None, image
return None, None, None

Loading…
Cancel
Save