Fine-tuning the Inception V3 network to classify Satellite images
Published: 2019-06-19


  This post uses the Keras framework to fine-tune an Inception V3 model for classifying satellite images, and then tests the result.

1. Overview

  Fine-tuning Inception V3 to classify satellite images can be divided roughly into four steps:

  • (1) Prepare the Satellite dataset;
  • (2) Build the Inception V3 network;
  • (3) Train;
  • (4) Test.

2. Preparing the dataset

2.1 The Satellite dataset

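  The dataset used for training and testing is the experimental satellite image dataset provided in …

  The Satellite dataset contains 6 classes of satellite images. The training set has 4800 images (800 per class) and the validation set has 1200 images (200 per class). Its directory structure is:

```
Satellite/
    train/
        glacier/
        rock/
        urban/
        water/
        wetland/
        wood/
    validation/
        glacier/
        rock/
        urban/
        water/
        wetland/
        wood/
```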
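Keras' flow_from_directory takes the class list from the subfolder names. Because of that, it can be worth checking the folder layout before training.

The following is a minimal sketch. The helper name count_images_per_class and the set of accepted file extensions are assumptions for illustration and are not part of the original code.

```python
from pathlib import Path

# Assumed set of extensions that count as image files (hypothetical choice)
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(split_dir):
    """Return {class_name: number_of_image_files} for each subfolder of split_dir."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():  # each class is one subfolder
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts
```

Suppose the dataset matches the description above. Then calling it on the train folder should report 800 images for each of the six classes, and calling it on the validation folder should report 200 each.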

3. The Inception V3 network

  To be added.

4. Training

4.1 Fine-tuning Inception V3 with Keras

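```python
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.layers import GlobalAveragePooling2D, Dense
from keras.models import Model

# Base Inception V3 model with ImageNet weights, without the fully connected top layers
base_model = InceptionV3(weights='imagenet', include_top=False)

# Add a new output head
x = base_model.output
x = GlobalAveragePooling2D()(x)  # global average pooling layer
x = Dense(1024, activation='relu')(x)
predictions = Dense(6, activation='softmax')(x)  # one output per satellite class

# The full model that is trained below
model = Model(inputs=base_model.input, outputs=predictions)
```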
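It helps to see the shapes in the new head concretely:

- For a 299×299 input, the Inception V3 base without its top outputs an 8×8×2048 feature map.
- GlobalAveragePooling2D averages each of the 2048 channels over the 8×8 positions.
- The two Dense layers then map that vector to 1024 hidden units and then to 6 class probabilities.

The NumPy sketch below reproduces that forward pass using random stand-in features and weights. It illustrates the computation only; it is not how Keras implements the layers.

```python
import numpy as np


def global_average_pool(feature_maps):
    """(batch, H, W, C) -> (batch, C): mean over the two spatial axes (channels_last)."""
    return feature_maps.mean(axis=(1, 2))


def relu(x):
    return np.maximum(x, 0.0)


def softmax(logits):
    # Subtract the row maximum for numerical stability
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def head_forward(feature_maps, w1, b1, w2, b2):
    """GAP -> Dense(1024, relu) -> Dense(6, softmax) with explicit weight arrays."""
    pooled = global_average_pool(feature_maps)  # (batch, 2048)
    hidden = relu(pooled @ w1 + b1)             # (batch, 1024)
    return softmax(hidden @ w2 + b2)            # (batch, 6)


rng = np.random.default_rng(0)
# Random stand-in for base_model.output on a batch of two 299x299 images
features = rng.standard_normal((2, 8, 8, 2048)).astype(np.float32)
w1 = rng.standard_normal((2048, 1024)).astype(np.float32) * 0.01
b1 = np.zeros(1024, dtype=np.float32)
w2 = rng.standard_normal((1024, 6)).astype(np.float32) * 0.01
b2 = np.zeros(6, dtype=np.float32)

probs = head_forward(features, w1, b1, w2, b2)
print(probs.shape)  # (2, 6)
```

Each row of probs sums to 1. The argmax of a row is the predicted class index.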

4.2 Generating batches of augmented data in real time with Keras

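```python
from keras.applications.inception_v3 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator

# Keras generates batches of augmented data in real time
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,  # scales each image to [-1, 1]; runs after augmentation
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
# The same random augmentation is applied to the validation data here;
# validation data is often left unaugmented (preprocessing_function only).
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Point at the dataset folders and generate augmented batches
train_generator = train_datagen.flow_from_directory(
    directory='satellite/data/train',
    target_size=(299, 299),  # input size expected by Inception V3
    batch_size=64)
val_generator = val_datagen.flow_from_directory(
    directory='satellite/data/validation',
    target_size=(299, 299),
    batch_size=64)
```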
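The [-1, 1] range mentioned in the comment comes from how preprocess_input in keras.applications.inception_v3 scales pixel values: it computes x / 127.5 - 1.

Below is a NumPy sketch of the same scaling. The name inception_preprocess is hypothetical and used only here.

```python
import numpy as np


def inception_preprocess(x):
    """Map pixel values from [0, 255] to [-1, 1] via x / 127.5 - 1 (Inception V3 style scaling)."""
    return np.asarray(x, dtype=np.float32) / 127.5 - 1.0


print(inception_preprocess([0, 127.5, 255]))  # [-1.  0.  1.]
```

Whatever scaling is used in training must also be applied to images at test time. Otherwise the model sees inputs in a different value range.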

4.3 Configuring transfer learning & fine-tuning

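```python
from keras.optimizers import Adagrad


# Transfer learning: freeze the whole base model and train only the new head
def setup_to_transfer_learning(model, base_model):
    for layer in base_model.layers:
        layer.trainable = False
    # Compile the model so the new trainable settings apply to the next training step
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])


# Fine-tuning: keep the first GAP_LAYER + 1 base layers frozen and unfreeze the rest
def setup_to_fine_tune(model, base_model):
    GAP_LAYER = 17  # max_pooling_2d_2
    for layer in base_model.layers[:GAP_LAYER + 1]:
        layer.trainable = False
    for layer in base_model.layers[GAP_LAYER + 1:]:
        layer.trainable = True
    # Small learning rate so the pretrained weights change slowly
    # (newer Keras versions name this argument learning_rate)
    model.compile(optimizer=Adagrad(lr=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
```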
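The two setup functions differ only in which base layers are trainable. In Keras, a change to layer.trainable takes effect only when the model is compiled again, which is why both functions end with model.compile.

The pure-Python sketch below shows the split that GAP_LAYER = 17 produces: layers 0 through 17 are frozen and every later layer is trainable. FakeLayer and freeze_first are hypothetical stand-ins, not Keras classes.

```python
from dataclasses import dataclass


@dataclass
class FakeLayer:
    # Hypothetical stand-in for a Keras layer: just a name and a trainable flag
    name: str
    trainable: bool = True


def freeze_first(layers, boundary):
    """Freeze layers[0..boundary] (inclusive), unfreeze the rest; return (frozen, trainable) counts."""
    for layer in layers[:boundary + 1]:
        layer.trainable = False
    for layer in layers[boundary + 1:]:
        layer.trainable = True
    frozen = sum(not layer.trainable for layer in layers)
    return frozen, len(layers) - frozen


layers = [FakeLayer("layer_%d" % i) for i in range(30)]
print(freeze_first(layers, 17))  # (18, 12)
```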

4.4 Running the training

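```python
# class_weight is not set: Keras 2 expects a dict for it, and the classes are
# balanced (800 training images each).
# In newer Keras versions, model.fit accepts generators and replaces fit_generator.

# Step 1: transfer learning (only the new head is trained)
setup_to_transfer_learning(model, base_model)
history_tl = model.fit_generator(generator=train_generator,
                                 steps_per_epoch=75,   # 4800 training images / batch size 64
                                 epochs=10,
                                 validation_data=val_generator,
                                 validation_steps=19)  # ceil(1200 validation images / 64)
model.save('satellite/train_dir/satellite_iv3_tl.h5')

# Step 2: fine-tuning (the upper Inception layers are trained as well)
setup_to_fine_tune(model, base_model)
history_ft = model.fit_generator(generator=train_generator,
                                 steps_per_epoch=75,
                                 epochs=10,
                                 validation_data=val_generator,
                                 validation_steps=19)
model.save('satellite/train_dir/satellite_iv3_ft.h5')
```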
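Both steps_per_epoch and validation_steps count batches, not images. With batch_size=64:

- One pass over the 4800 training images is 4800 / 64 = 75 steps.
- One pass over the 1200 validation images is ceil(1200 / 64) = 19 steps. The last batch holds only 48 images.

A small helper for this arithmetic follows. The name steps_for is hypothetical.

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to see every sample once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


print(steps_for(4800, 64))  # 75
print(steps_for(1200, 64))  # 19
```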

5. Testing

5.1 Testing a single image

```python
# -*- coding: utf-8 -*-
"""Test the satellite classifier with an h5 model file."""
# ================================================================
import numpy as np
from skimage import io
from skimage.transform import resize
from keras.applications.inception_v3 import preprocess_input
from keras.models import load_model

# Input size used in training (target_size in flow_from_directory)
IMAGE_SIZE = (299, 299)


def img_preprocess(image_path):
    """Preprocess an image, given its path, the same way as in training.

    Argument:
        image_path: str
            path of the input image
    Return:
        image_data: array
            preprocessed image batch of shape (1, 299, 299, 3)
    """
    img_array = io.imread(image_path)
    # Resize to the training input size, keeping the 0-255 value range
    img_array = resize(img_array, IMAGE_SIZE, preserve_range=True)
    # Scale to [-1, 1] with the same preprocess_input the training generators used,
    # so test inputs match the training preprocessing
    img_norm = preprocess_input(img_array.astype(np.float32))
    return np.expand_dims(img_norm, axis=0)


def index_to_label(index):
    """Convert a label index into a human-readable label.

    Argument:
        index: int
            label index
    Return:
        human_label: str
            human-readable label
    """
    # Same order as the class indices from flow_from_directory (sorted folder names)
    labels = ["glacier", "rock", "urban", "water", "wetland", "wood"]
    return labels[index]


def classifier_satellite_byh5(image_path, model_file_path):
    """Classify a single image with a trained model.

    Argument:
        image_path: str
            path of the input image
        model_file_path: str
            trained h5 model file
    Return:
        human_label: str
            human-readable image label
    """
    image_data = img_preprocess(image_path)
    model = load_model(model_file_path)  # load the model file
    predictions = model.predict(image_data)
    return index_to_label(np.argmax(predictions))


def classifier_satellite_byh5_hci(image_path):
    """Classify an image passed in from the interactive interface.

    Argument:
        image_path: str
    Return:
        human_label: str
            human-readable image label
    """
    # Model file; change this path when using a new model
    model_file_path = "satellite/train_dir/models/satellite_iv3_ft.h5"
    return classifier_satellite_byh5(image_path, model_file_path)


# Test a single image
if __name__ == "__main__":
    image_path = "satellite/data/train/glacier/40965_91335_18.jpg"
    model_file_path = "satellite/train_dir/models/satellite_iv3_ft.h5"
    human_label = classifier_satellite_byh5(image_path, model_file_path)
    print(human_label)
```
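index_to_label hard-codes the class order. That order is correct only if it matches the indices assigned during training.

When no classes list is passed, flow_from_directory sorts the subfolder names alphabetically. In the training script, the actual mapping is available as train_generator.class_indices.

The sketch below mimics that rule so the hard-coded list can be checked against a dataset folder. The helper names class_indices_like_flow_from_directory and labels_from_indices are hypothetical.

```python
import os


def class_indices_like_flow_from_directory(directory):
    """Return {subfolder_name: index}, with subfolders sorted alphabetically."""
    names = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(names)}


def labels_from_indices(class_indices):
    """Invert {label: index} into a list where position i holds the label for index i."""
    return [name for name, _ in sorted(class_indices.items(), key=lambda kv: kv[1])]
```

For the six Satellite class folders, labels_from_indices should return the same list as the one in index_to_label.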

6. Visual classification interface

6.1 Interface design

```python
# encoding: utf-8
"""Interactive interface: classify satellite images with the trained model."""
import os
import tkinter
import tkinter.filedialog
import tkinter.messagebox
from tkinter import END, NORMAL, DISABLED

from PIL import Image, ImageTk

# Backend with model loading and classification; assumed to be the
# single-image test script from Section 5.1 saved as test_satellite_bypb.py
import test_satellite_bypb

# Window properties
root = tkinter.Tk()
root.title('Satellite image classification')
root.geometry('800x600')

formatImg = ['jpg']  # accepted image file extensions


def resize(w, h, w_box, h_box, pil_image):
    # Scale a PIL image to fit inside a w_box x h_box box while keeping its aspect ratio
    f1 = 1.0 * w_box / w  # 1.0 forces float division in Python 2
    f2 = 1.0 * h_box / h
    factor = min([f1, f2])
    width = int(w * factor)
    height = int(h * factor)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    return pil_image.resize((width, height), Image.LANCZOS)


def show_result(text):
    # The result box is read-only (DISABLED); delete/insert are ignored while it is
    # disabled, so enable it only while its content is being replaced
    text_showClass.config(state=NORMAL)
    text_showClass.delete('1.0', END)
    text_showClass.insert(END, text)
    text_showClass.config(state=DISABLED)


def showImg():
    img1 = entry_imgPath.get()  # path of the selected image
    pil_image = Image.open(img1)
    # Desired display size
    w_box = 400
    h_box = 400
    w, h = pil_image.size  # original image size
    pil_image_resized = resize(w, h, w_box, h_box, pil_image)
    # Convert the PIL image into a Tkinter PhotoImage
    tk_image = ImageTk.PhotoImage(pil_image_resized)
    img = tkinter.Label(root, image=tk_image, width=w_box, height=h_box)
    img.image = tk_image  # keep a reference so the image is not garbage-collected
    img.place(x=50, y=150)


def choose_file():
    show_result('')  # clear the previous result before choosing another image
    selectFileName = tkinter.filedialog.askopenfilename(title='Select file')
    if selectFileName[-3:].lower() not in formatImg:
        tkinter.messagebox.showerror(title='Error', message='No image selected or unsupported image format')
        return
    e.set(selectFileName)  # show the path in the entry box
    showImg()


def outputOfModel():
    # Run recognition and display the predicted class
    show_result('')
    img_path = entry_imgPath.get()
    if not os.path.exists(img_path):
        tkinter.messagebox.showerror(title='Error', message='No image file selected or unsupported image format')
        return
    # Classify with the trained model and write the label into the result box
    human_label = test_satellite_bypb.classifier_satellite_byh5_hci(img_path)
    show_result('%s\n' % human_label)


##################
# Window widgets
##################
e = tkinter.StringVar()  # holds the selected file path
# Label: select an image
label_selectImg = tkinter.Label(root, text='Select image:')
label_selectImg.grid(row=0, column=0)
# Entry: shows the image file path
entry_imgPath = tkinter.Entry(root, width=80, textvariable=e)
entry_imgPath.grid(row=0, column=1)
# Button: choose an image file
button_selectImg = tkinter.Button(root, text='Browse', command=choose_file)
button_selectImg.grid(row=0, column=2)
# Button: run recognition
button_recogImg = tkinter.Button(root, text='Recognize', command=outputOfModel)
button_recogImg.grid(row=0, column=3)
# Text: shows the predicted class (read-only)
text_showClass = tkinter.Text(root, width=20, height=1, font=('Helvetica', 18))
text_showClass.grid(row=1, column=1)
text_showClass.config(state=DISABLED)

root.mainloop()
```
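resize() fits the image into the 400 × 400 display box. It scales both sides by the smaller of the two ratios w_box / w and h_box / h, so the aspect ratio is preserved.

The same arithmetic without Pillow is shown below. The name fit_in_box is hypothetical.

```python
def fit_in_box(w, h, w_box, h_box):
    """Largest (width, height) with the aspect ratio of (w, h) that fits in w_box x h_box."""
    factor = min(w_box / w, h_box / h)
    return int(w * factor), int(h * factor)


print(fit_in_box(800, 400, 400, 400))  # (400, 200)
print(fit_in_box(256, 256, 400, 400))  # (400, 400)
```

Small images are scaled up as well. This is why a 256 × 256 tile fills the whole 400 × 400 box.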

6.2 Backend core code: loading the model and classifying

  The backend is the same code as the single-image test script in Section 5.1; the interface in 6.1 imports it as the module test_satellite_bypb.

6.3 The interface in action

(Figure: screenshot of the interactive classification interface)

Reposted from: https://www.cnblogs.com/chenzhen0530/p/10686178.html
