From 2f97072a550d21f4375c04baa31348d9b3938c43 Mon Sep 17 00:00:00 2001 From: UnknownObject Date: Wed, 6 Dec 2023 15:31:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E6=9B=B4=E8=AF=A1?= =?UTF-8?q?=E5=BC=82=E7=9A=84BUG=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ver3.py | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 ver3.py diff --git a/ver3.py b/ver3.py new file mode 100644 index 0000000..aa284c0 --- /dev/null +++ b/ver3.py @@ -0,0 +1,64 @@ +from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize +from fastai.learner import load_learner +from fastai.metrics import error_rate +from pathlib import Path +from fastai.data.block import CategoryBlock, DataBlock +from fastai.vision.all import * +from fastai.vision.augment import Resize, aug_transforms +from fastai.vision.core import imagenet_stats +from fastai.vision.data import ImageBlock +from fastai.vision.learner import cnn_learner, vision_learner +from torchvision.models import resnet34 +from PIL import Image + + +def open_image(image_path): + img = Image.open(image_path) + img_cvt = img.resize((460,460)) + return img + + +def predict_image(image_path): + # 加载模型 + model = load_learner('G:\\Users\\15819\\Desktop\\model01.pkl') + # 读取图片并转换为Tensor + img = open_image(image_path) # 读取指定路径(image_path)下的图像文件 + # 进行预测 + pred_class, pred_idx, outputs = model.predict(img) + # 获取置信度 + # 检查输出张量的维度 + if outputs.dim() == 0: + confidence = float(outputs) + else: + confidence = float(outputs[pred_idx]) + + return pred_class, confidence + +def train(): + data_path = Path('G:\\Users\\15819\\Desktop\\Images2') + export_path = Path('G:\\Users\\15819\\Desktop') + blocks = (ImageBlock, CategoryBlock) + batch_size = 32 + dls = DataBlock( + blocks=blocks, + get_items=get_image_files, + splitter=RandomSplitter(), + get_y=parent_label, + item_tfms=Resize(460), + batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)] + ).dataloaders(data_path, num_workers=4, bs=batch_size) + + model = vision_learner(dls, resnet34, metrics=error_rate) + model.fine_tune(5, freeze_epochs=3) + model.export('G:\\Users\\15819\\Desktop\\model01.pkl') + +def main(): + #train() + + image_path = 'G:\\Users\\15819\\Desktop\\Images2\\SmallCar\\京M88888.jpg' + pred_class, confidence = predict_image(image_path) + print(f"图片类别: {pred_class}, 置信度: {confidence}") + + +if __name__ == '__main__': + main()