ARJ_DenseNet BMR Model Training

Without further ado, here is the model training code.

densenet_arj_BMR.py

import time

from tensorflow.keras.applications.densenet import DenseNet169
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow.keras as keras
from arj_t.plt_graph import show_graph
from common_para import train_dir, val_dir, station, EPOCHS_1,EPOCHS_2, batch_size, CLASS_WEIGHT, classes

input_shape = (224, 224)
date_ = time.strftime('%Y%m%d', time.localtime())
cpkt_path = f'./ckpt/ARJ_Densenet_ckpt{station}_20231017-1.h5'
model_path = f'./ckpt/ARJ_Densenet_MODEL{station}_{date_}.h5'


class ArjDensenetModel(object):

    def __init__(self):
        self.base_model = DenseNet169(weights='imagenet', include_top=False)

        # Generalization was poor, so image augmentation was tested here
        self.train_gen = ImageDataGenerator(rescale=1.0 / 255.0,
                                            # rotation_range=45,
                                            # width_shift_range=0.2,
                                            # height_shift_range=0.2,
                                            # brightness_range=(0, 0.3),
                                            # shear_range=0.2,  # float; shear intensity (shear angle in radians, counter-clockwise)
                                            # zoom_range=[0.5, 1.5],  # values below 1.0 zoom in, values above 1.0 zoom out
                                            # horizontal_flip=True,
                                            # vertical_flip=True,
                                            # fill_mode='constant',
                                            # cval=0
                                            )

        # self.train_gen = ImageDataGenerator(rescale=1.0 / 255.0)
        self.val_gen = ImageDataGenerator(rescale=1.0 / 255.0)

    # Load the local training/validation images and build the generators
    def get_local_data(self):
        self.train_gen = self.train_gen.flow_from_directory(
            directory=train_dir,
            target_size=input_shape,
            batch_size=batch_size,
            class_mode='binary',  # switched between 'binary' and 'categorical' during experiments
            shuffle=True,
            # save_to_dir=r'D:\AOI Gray Image-OA\dataset\BMR\train_trans2',
            # save_format='jpg',
            # save_prefix='trans_'
        )
        self.val_gen = self.val_gen.flow_from_directory(
            directory=val_dir,
            target_size=input_shape,
            batch_size=batch_size,
            class_mode='binary',  # switched between 'binary' and 'categorical' 2022/5/15
            shuffle=True
        )
        return None

    def refine_basemode(self):
        """
        获取VGG16 basemode
        只获取全连接层以前的卷积和池化层,并进行参数冻结,也就是使用原有训练好的参数
        自主增加隐藏层和全连接层进行训练,获得目标模型
        :return:
        """
        # Take the output of the no-top base model (everything before the fully connected head)
        x = self.base_model.outputs[0]
        # Add global average pooling, hidden layers, and the output layer
        x = keras.layers.GlobalAveragePooling2D()(x)
        x = keras.layers.Dense(2048, activation='relu')(x)
        # x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Dense(1024, activation='relu')(x)
        out = keras.layers.Dense(2, activation='softmax')(x)

        # Build the new model
        new_model = keras.models.Model(inputs=self.base_model.inputs, outputs=out)

        # Freeze the pretrained DenseNet base-model weights
        self.freeze_base_model()

        # Compile new_model
        # Training results were not good, so the initial learning rate was adjusted
        # Initial learning rate 0.01 -> 0.001
        opt = keras.optimizers.Adam(learning_rate=0.001)
        new_model.compile(
            # optimizer=opt,  # optimizer
            # # With class_mode='categorical' the generator returns one-hot encoded labels,
            # # so categorical_crossentropy (multi-class log loss) must be used here.
            # # With class_mode='binary' it returns 1D binary labels, so the loss should be sparse_categorical_crossentropy.
            # loss='sparse_categorical_crossentropy',  # cross-entropy loss for classification
            # metrics=['accuracy']
            # binary_crossentropy is paired with a sigmoid output for binary classification
            # categorical_crossentropy is paired with softmax
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        return new_model

    # Freeze the base model's layers
    def freeze_base_model(self):
        for layer in self.base_model.layers:
            layer.trainable = False
        return None

    # Train new_model
    def fit(self, model):
        # Load the local data
        self.get_local_data()

        # Define the checkpoint callback
        ckpt = keras.callbacks.ModelCheckpoint(
            filepath=cpkt_path,
            monitor='val_accuracy',
            save_freq='epoch',
            save_weights_only=True,
            save_best_only=True
        )

        # Early stopping
        el1 = keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=15,
            verbose=2,
            mode='auto'
        )

        # Learning-rate reduction schedule
        rc1 = keras.callbacks.ReduceLROnPlateau(
            monitor='val_accuracy',
            factor=0.1,  # learning-rate reduction factor: new_lr = lr * factor
            patience=5,  # reduce the learning rate if val_accuracy does not improve for 5 epochs
            mode='auto',
            verbose=1,  # 1 prints update messages, 0 is silent
            # epsilon=0.0001,  # threshold for detecting a plateau
            min_lr=0,
            cooldown=0
        )

        # Train the model
        # class_weight was added here at one point

        # The class_weight version below is temporarily commented out.
        his1 = model.fit(self.train_gen, validation_data=self.val_gen,
                         epochs=EPOCHS_1, callbacks=[ckpt, rc1, el1])

        # his1 = model.fit(self.train_gen, validation_data=self.val_gen,
        #                  epochs=EPOCHS_1, callbacks=[ckpt, rc1, el1], class_weight=CLASS_WEIGHT)

        print('first step end')

        # Unfreeze all layers for fine-tuning
        for layer in model.layers:
            layer.trainable = True

        # Early stopping
        el2 = keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=11,
            verbose=2,
            mode='auto'
        )

        # Learning-rate reduction schedule
        rc2 = keras.callbacks.ReduceLROnPlateau(
            monitor='val_accuracy',
            factor=0.1,  # learning-rate reduction factor: new_lr = lr * factor
            patience=5,  # reduce the learning rate if val_accuracy does not improve for 5 epochs
            mode='auto',
            verbose=1,  # 1 prints update messages, 0 is silent
            # epsilon=0.0001,  # threshold for detecting a plateau
            min_lr=0,
            cooldown=0
        )

        opt = keras.optimizers.Adam(learning_rate=0.001)
        model.compile(
            optimizer=opt,
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        # Train the model (fine-tuning stage)
        # model.load_weights(cpkt_path)
        his2 = model.fit(self.train_gen, validation_data=self.val_gen,
                         epochs=EPOCHS_2, callbacks=[ckpt, rc2, el2], class_weight={0: 1, 1: 1.5})

        # # Train the model
        # his2 = model.fit(self.train_gen, validation_data=self.val_gen,
        #                  epochs=EPOCHS_2, callbacks=[ckpt, rc2, el2], class_weight={0: 1, 1: 2, 2: 3})

        print('END STEP')
        return his1, his2


if __name__ == '__main__':
    arj_model = ArjDensenetModel()
    model = arj_model.refine_basemode()
    his1, his2 = arj_model.fit(model)
    # # Save the model
    # model.save(model_path)
    show_graph(his1)
    show_graph(his2)
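
The class weights passed to model.fit above ({0: 1, 1: 1.5}) are hand-tuned. If you would rather derive them from the actual class frequencies, a minimal sketch is shown below; it assumes train_generator is the DirectoryIterator returned by flow_from_directory (the variable name is illustrative, not part of the original code):

import numpy as np

# counts[i] = number of training images whose label index is i
counts = np.bincount(train_generator.classes)
n_classes = len(counts)
# "balanced" weighting: total_samples / (n_classes * samples_per_class)
class_weight = {i: counts.sum() / (n_classes * c) for i, c in enumerate(counts)}
print(class_weight)  # e.g. {0: 0.8, 1: 1.4} for an imbalanced G/P split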

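The show_graph helper imported from arj_t.plt_graph is not included in this post. A minimal sketch of what such a helper could look like, assuming it simply plots the accuracy and loss curves from a Keras History object (this implementation is an assumption, not the author's actual code):

import matplotlib.pyplot as plt


def show_graph(history):
    # history.history holds one list per metric, indexed by epoch
    hist = history.history
    plt.figure(figsize=(10, 4))
    # Accuracy curves
    plt.subplot(1, 2, 1)
    plt.plot(hist['accuracy'], label='train_acc')
    plt.plot(hist['val_accuracy'], label='val_acc')
    plt.xlabel('epoch')
    plt.legend()
    # Loss curves
    plt.subplot(1, 2, 2)
    plt.plot(hist['loss'], label='train_loss')
    plt.plot(hist['val_loss'], label='val_loss')
    plt.xlabel('epoch')
    plt.legend()
    plt.show()
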
common_para.py code

train_dir = r"D:\new_data\BMR_TRAIN\train"
val_dir = r"D:\new_data\BMR_TRAIN\validate"
station = '_ALL_BMR'
batch_size = 32
EPOCHS_1 = 10
EPOCHS_2 = 40
CLASS_WEIGHT = {0: 1., 1: 1., 2: 1.}
threshold_value = 0
classes = 2
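
flow_from_directory assigns class indices alphabetically from the sub-folder names under train_dir, so it is worth confirming that the resulting mapping matches the predict_dict used later in the prediction script ({0: 'G', 1: 'P', ...}). A quick check, assuming train_dir contains one sub-folder per class:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

from common_para import train_dir

# Throwaway generator, used only to inspect the folder-name -> index mapping
gen = ImageDataGenerator(rescale=1.0 / 255.0).flow_from_directory(
    directory=train_dir,
    target_size=(224, 224),
    batch_size=1,
    class_mode='binary'
)
print(gen.class_indices)  # e.g. {'G': 0, 'P': 1} if the sub-folders are named G and P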

Model prediction code

BMR_IPS_135K_predict.py

import os

import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array

import densenet_arj_BMR
import inceptionRestnet_arj_t
import resnet101_arj_BMR
import xception_arj_BMR

MODEL_NAME = 'densenet'

val_path = r'D:\AOI Gray Image-OA\dataset\case1\135K-ISR-IPS\validate'
# val_path = r'D:\AOI Gray Image-OA\dataset\error\W_to_G'
other_path = r'D:\AOI Gray Image-OA\AOI IMAGE-20220513\A6Q\ISR\复判后-G'
test_path = r'D:\new_data\BMR表外测试\BMR\A1A\P'
# TARGET_SIZE = (299, 299)
# DEFECT_TYPE = 'P'
# error_path = fr"D:\AOI Gray Image-OA\dataset\BMR\{MODEL_NAME}"


def get_ckptpath_model():
    # arj_model = tensorflow.keras.models.Model()
    ckpt_path = ''
    # target_size = (224, 224)
    if MODEL_NAME == 'xception':
        ckpt_path = xception_arj_BMR.cpkt_path
        arj_model = xception_arj_BMR.ArjResnet101Model()
        target_size = xception_arj_BMR.input_shape
    elif MODEL_NAME == 'inceptionRestnet':
        ckpt_path = inceptionRestnet_arj_t.cpkt_path
        arj_model = inceptionRestnet_arj_t.ArjInceptionRestnetModel()
        target_size = inceptionRestnet_arj_t.input_shape
    elif MODEL_NAME == 'densenet':
        ckpt_path = densenet_arj_BMR.cpkt_path
        arj_model = densenet_arj_BMR.ArjDensenetModel()
        target_size = densenet_arj_BMR.input_shape
    elif MODEL_NAME == 'resnet101':
        ckpt_path = resnet101_arj_BMR.cpkt_path
        arj_model = resnet101_arj_BMR.ArjResnet101Model()
        target_size = resnet101_arj_BMR.input_shape
    else:
        raise ValueError(f'Unsupported MODEL_NAME: {MODEL_NAME}')

    return ckpt_path, arj_model, target_size


# Collect the absolute paths (including file names) of the images to predict
def get_img_paths(defect_type, path):
    img_path = os.path.join(path, defect_type)
    img_paths = []
    for root, dirs, files in os.walk(img_path):
        for file in files:
            # print(file[-3:])
            if file[-3:] == 'jpg':
                img_paths.append(os.path.join(root, file))
    return img_paths


def bmr_ips_predict(img_paths, error_path, defect_type='G'):
    ckpt_path, arj_model, input_shape = get_ckptpath_model()
    model = arj_model.refine_basemode()
    print(ckpt_path)
    model.load_weights(ckpt_path)
    print(model.summary())

    predict_dict = {0: 'G', 1: 'P', 2: 'W'}

    # Load each image and run the prediction
    white_cnt = 0
    good_cnt = 0
    repair_cnt = 0

    threshold_ls = []
    for img_path in img_paths:
        img_arr = load_img(img_path, target_size=input_shape)
        img = img_arr
        # print(img_path)
        # Convert the image to an array
        img_arr = img_to_array(img_arr)
        # print(img.shape)
        # Normalize (must match the 1/255 rescale used during training)
        # img_arr = preprocess_input(img_arr)

        img_arr /= 255.
        # print(type(img_arr))
        # img_arr = preprocess_input(img_arr)

        # img_arr /= 127.5
        # img_arr -= 1.


        # Reshape to add the batch dimension
        img_arr = img_arr.reshape(1, img_arr.shape[0], img_arr.shape[1], img_arr.shape[2])
        # print(img.shape)
        # print(img_arr)

        y_predict = model.predict(img_arr)

        index = np.argmax(y_predict)
        # Confidence of the predicted class, used as a threshold
        threshold = y_predict[0][index]
        # print(img_path.split('\\')[-1])
        # print(y_predict[0], ' >> ', threshold)
        # threshold_ls.append(threshold)
        # print(y_predict)
        y_predict = predict_dict[index]

        # print(index)
        # print(y_predict)

        # if index == 0:
        #     good_cnt += 1
        # else:
        #     repair_cnt += 1

        # Save misclassified images
        # Prediction result is G

        # save_img_name = str(round(threshold,2))+'_'+img_path.split('\\')[-1]
        save_img_name = img_path.split('\\')[-1]
        if index == 0:
            # A threshold can be used here to tune how readily images are classified as G
            if threshold > 0:
                good_cnt += 1
                # print(good_cnt)
                # print(img_path[-10:])
                # If the image originally came from the P folder
                if defect_type == 'P':
                    threshold_ls.append(threshold)
                    img.save(os.path.join(error_path, 'AI_P_TO_G', save_img_name))
                    os.remove(img_path)
                # If the image originally came from the W folder
                if defect_type == 'W':
                    img.save(os.path.join(error_path, 'AI_W_TO_G', save_img_name))
                    os.remove(img_path)
            else:
                repair_cnt += 1
        elif index == 1:
            repair_cnt += 1
            if defect_type == 'G':
                threshold_ls.append(threshold)
                img.save(os.path.join(error_path, 'AI_G_TO_P', save_img_name))
                os.remove(img_path)
            if defect_type == 'W':
                img.save(os.path.join(error_path, 'AI_W_TO_P', save_img_name))
                os.remove(img_path)
        elif index == 2:
            white_cnt += 1
            if defect_type == 'G':
                img.save(os.path.join(error_path, 'AI_G_TO_W', save_img_name))
                os.remove(img_path)
            if defect_type == 'P':
                img.save(os.path.join(error_path, 'AI_P_TO_W', save_img_name))
                os.remove(img_path)
        else:
            print('Unexpected: there should be no fourth class!')
        # print(y_predict)
        # print('**************************')

    # pd.DataFrame(data=threshold_ls).to_csv('./threshold.csv', encoding='utf-8')
    print(threshold_ls)
    print('good_cnt :  %d' % good_cnt)
    print('repair_cnt :  %d' % repair_cnt)
    # print('white_cnt :  %d' % white_cnt)


# if __name__ == '__main__':
#     paths = get_img_paths(DEFECT_TYPE, test_path)
#     bmr_ips_predict(paths,error_path)

Overall model prediction code

predict_all.py

import BMR_IPS_135K_predict

# This script runs predictions with all the models 2023/10/17

img_path = r'D:\new_data\BMR_TRAIN\test\WHITE'

DEFECT_TYPE = 'P'

paths = BMR_IPS_135K_predict.get_img_paths(DEFECT_TYPE, img_path)

BMR_IPS_135K_predict.bmr_ips_predict(paths, img_path, DEFECT_TYPE)
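
To run the prediction over every label folder instead of hard-coding a single DEFECT_TYPE, a small extension could loop over the labels; this sketch assumes the sub-folders under img_path are named after the same G/P/W labels used in predict_dict:

for defect_type in ('G', 'P', 'W'):
    paths = BMR_IPS_135K_predict.get_img_paths(defect_type, img_path)
    if paths:
        BMR_IPS_135K_predict.bmr_ips_predict(paths, img_path, defect_type)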
