"""
Module authors:
    AI code structure: 刘钰廷, 冯雅君
    Code optimization and cleanup: 王昱博, 冯昌盛
    AI model training / error correction: 刘钰廷, 冯雅君, 冯昌盛
    Code integration / packaging: 王昱博

Module purpose:
    Image-classification AI used to distinguish the specific type of a
    licence plate.
"""
import functools
import threading

import cv2
from PIL import Image
from pathlib import Path
from fastai.vision.all import *
from fastai.metrics import error_rate
from fastai.learner import load_learner
from torchvision.models import resnet34
from fastai.vision.data import ImageBlock
from fastai.vision.core import imagenet_stats
from fastai.data.block import CategoryBlock, DataBlock
from fastai.vision.augment import Resize, aug_transforms
from fastai.vision.learner import cnn_learner, vision_learner
from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize


class ClassificationAI:
    """Train and run a fastai (ResNet-34) image classifier for licence-plate types."""

    # Raw class labels (the dataset sub-folder names, taken via
    # ``parent_label`` in TrainAI) mapped to the user-facing labels that
    # PredictImage returns.
    _CLASS_NAMES = {
        'ForeignV': '外籍车辆',
        'In-fieldV': '场内车辆',
        'large-scaleNewenergyV': '大型新能源车辆',
        'MediumLarge-sizedV': '中/大型车辆',
        'MilitaryPoliceEmergencyV': '军/警/应急车辆',
        'SmallCar': '小型轿车',
        'SmallNewEnergyV': '小型新能源轿车',
    }
    # Label returned for any raw class not listed above.
    _UNKNOWN_CLASS = '未知'

    # The cached learner (see _LoadModel) is shared between PredictImage
    # calls, and fastai's Learner is not documented as thread-safe, so
    # predictions are serialized through this lock.
    _predict_lock = threading.Lock()

    @staticmethod
    def ConvertImage(cv_img: cv2.Mat) -> Image.Image:
        """Convert an OpenCV image to a 460x460 RGB PIL image.

        Args:
            cv_img: Image array in OpenCV channel order: BGR (3 channels),
                BGRA (4 channels) or single-channel grayscale.

        Returns:
            An RGB ``PIL.Image.Image`` resized to 460x460.
        """
        # Pick the conversion from the channel count so grayscale and BGRA
        # inputs do not make cvtColor(COLOR_BGR2RGB) raise.
        if cv_img.ndim == 2 or cv_img.shape[2] == 1:
            code = cv2.COLOR_GRAY2RGB
        elif cv_img.shape[2] == 4:
            code = cv2.COLOR_BGRA2RGB
        else:
            code = cv2.COLOR_BGR2RGB
        # NOTE(review): this squashes the image to 460x460, whereas training
        # used Resize(460) (fastai's default crop method); confirm the
        # aspect-ratio distortion at inference time is intended.
        return Image.fromarray(cv2.cvtColor(cv_img, code)).resize((460, 460))

    @staticmethod
    def ConvertClassifyResult(cla: str) -> str:
        """Map a raw model class label to its user-facing label.

        Args:
            cla: Class label predicted by the model.

        Returns:
            The mapped label, or '未知' when ``cla`` is not a known class.
        """
        return ClassificationAI._CLASS_NAMES.get(cla, ClassificationAI._UNKNOWN_CLASS)

    @classmethod
    def TrainAI(cls, data_set_path: str, export_path: str, batch_size: int = 32,
                epochs: int = 5, freeze_epochs: int = 3,
                seed: 'int | None' = None) -> None:
        """Fine-tune a ResNet-34 classifier and export it as ``model.pkl``.

        Images are collected recursively under ``data_set_path``, and each
        image's label is the name of its parent folder.

        Args:
            data_set_path: Root directory of the training images.
            export_path: Directory that receives ``model.pkl``; it is created
                if it does not exist.
            batch_size: Mini-batch size for the data loaders.
            epochs: Number of epochs trained after unfreezing.
            freeze_epochs: Number of epochs trained with the backbone frozen.
            seed: Optional seed for the random train/validation split
                (``None`` gives a different split on every run).
        """
        blocks = (ImageBlock, CategoryBlock)
        dls = DataBlock(
            blocks=blocks,
            get_items=get_image_files,
            splitter=RandomSplitter(seed=seed),
            get_y=parent_label,
            item_tfms=Resize(460),
            batch_tfms=[*aug_transforms(size=224, min_scale=0.75),
                        Normalize.from_stats(*imagenet_stats)],
        ).dataloaders(data_set_path, num_workers=0, bs=batch_size)  # num_workers=0: load in the main process

        model = vision_learner(dls, resnet34, metrics=error_rate)
        # First `freeze_epochs` epochs with the pretrained body frozen, then
        # `epochs` epochs with the whole network unfrozen.
        model.fine_tune(epochs, freeze_epochs=freeze_epochs)

        export_dir = Path(export_path)
        export_dir.mkdir(parents=True, exist_ok=True)
        model.export(export_dir / 'model.pkl')

    @staticmethod
    @functools.lru_cache(maxsize=4)
    def _LoadModel(model_path: str, mtime_ns: int):
        """Load (and memoize) the exported learner at ``model_path``.

        ``mtime_ns`` is not used in the body; it is part of the cache key so
        that a model file re-exported by TrainAI is reloaded instead of
        being served stale from the cache.
        """
        # SECURITY: load_learner unpickles the file; only load trusted models.
        return load_learner(model_path)

    @classmethod
    def PredictImage(cls, image: cv2.Mat, model_path: str) -> tuple:
        """Classify a single OpenCV image with the exported model.

        Args:
            image: Image array in OpenCV channel order (see ConvertImage).
            model_path: Path to a ``.pkl`` file produced by ``Learner.export``.

        Returns:
            ``(label, confidence)``. ``label`` is the user-facing class name
            from ConvertClassifyResult. ``confidence`` is the model's output
            value for the predicted class as a float; with fastai's default
            classification loss this is a softmax probability.
        """
        path = Path(model_path)
        # Loading a learner is expensive, so reuse it across calls. The
        # file's mtime is part of the key so that a retrained model is seen.
        model = cls._LoadModel(str(path), path.stat().st_mtime_ns)

        img = cls.ConvertImage(image)
        with cls._predict_lock:
            pred_class, pred_idx, outputs = model.predict(img)

        # Guard for a 0-d output tensor; otherwise take the predicted class's entry.
        if outputs.dim() == 0:
            confidence = float(outputs)
        else:
            confidence = float(outputs[pred_idx])

        return cls.ConvertClassifyResult(pred_class), confidence