博客
关于我
Keras自定义网络进行十分类图像识别
阅读量:262 次
发布时间:2019-03-01

本文共 4177 字,大约阅读时间需要 13 分钟。

import osimport numpy as npimport tensorflow as tfimport randomimport seaborn as snsimport matplotlib.pyplot as pltfrom keras.models import Sequential, Modelfrom keras.layers import Dense, Dropout, Activation, Flatten, Inputfrom keras.layers.convolutional import Conv2D, MaxPooling2Dfrom keras.optimizers import RMSprop, Adam, SGDfrom keras.preprocessing import imagefrom keras.preprocessing.image import ImageDataGeneratorfrom keras.utils import np_utilsfrom sklearn.model_selection import train_test_split

图片预处理

def read_and_process_image(data_dir,width=32, height=32, channels=3, preprocess=False):        train_classes= [data_dir +  i for i in os.listdir(data_dir) ]    train_images = []    for train_class in train_classes:        train_images= train_images + [train_class + "/" + i for i in os.listdir(train_class)]        random.shuffle(train_images)        def read_image(file_path, preprocess):        img = image.load_img(file_path, target_size=(height, width))        x = image.img_to_array(img)        x = np.expand_dims(x, axis=0)        # if preprocess:            # x = preprocess_input(x)        return x        def prep_data(images, proprocess):        count = len(images)        data = np.ndarray((count, height, width, channels), dtype = np.float32)                for i, image_file in enumerate(images):            image = read_image(image_file, preprocess)            data[i] = image                return data        def read_labels(file_path):        labels = []        for i in file_path:            if 'airplane' in i:                label = 0            elif 'automobile' in i:                label = 1            elif 'bird' in i:                label = 2            elif 'cat' in i:                label = 3            elif 'deer' in i:                label = 4            elif 'dog' in i:                label = 5            elif 'frog' in i:                label = 6            elif 'horse' in i:                label = 7            elif 'ship' in i:                label = 8            elif 'truck' in i:                label = 9            labels.append(label)                return labels        X = prep_data(train_images, preprocess)    labels = read_labels(train_images)        assert X.shape[0] == len(labels)        print("Train shape: {}".format(X.shape))        return X, labels

读取训练集,以及测试集

# 读取训练集图片WIDTH = 32HEIGHT = 32CHANNELS = 3X, y = read_and_process_image('D:/Python Project/cifar-10/train/',width=WIDTH, height=HEIGHT, channels=CHANNELS)# 读取测试集图片WIDTH = 32HEIGHT = 32CHANNELS = 3test_X, test_y = read_and_process_image('D:/Python Project/cifar-10/test/',width=WIDTH, height=HEIGHT, channels=CHANNELS)# 统计ysns.countplot(y)# 统计ysns.countplot(test_y)

one-hot编码

train_y = np_utils.to_categorical(y)test_y = np_utils.to_categorical(test_y)

显示图片

# 显示图片def show_picture(X, idx):    plt.figure(figsize=(10,5), frameon=True)    img = X[idx,:,:,::-1]    img = img/255    plt.imshow(img)    plt.show()for idx in range(0,3):    show_picture(X, idx)

定义模型

num_classes=10model = Sequential()model.add(Conv2D(32 ,3 ,input_shape=(HEIGHT,WIDTH,CHANNELS),activation='relu',padding='same'))model.add(Conv2D(32 ,3 ,activation='relu',padding='same'))model.add(MaxPooling2D(pool_size=2))model.add(Conv2D(64 ,3 ,activation='relu',padding='same'))model.add(Conv2D(64 ,3 ,activation='relu',padding='same'))model.add(MaxPooling2D(pool_size=2))model.add(Conv2D(128 ,3 ,activation='relu',padding='same'))model.add(Conv2D(128 ,3 ,activation='relu',padding='same'))model.add(MaxPooling2D(pool_size=2))model.add(Conv2D(256 ,3 ,activation='relu',padding='same'))model.add(Conv2D(256 ,3 ,activation='relu',padding='same'))model.add(MaxPooling2D(pool_size=2))model.add(Flatten())model.add(Dense(256, activation='relu'))model.add(Dropout(0.5))model.add(Dense(256, activation='relu'))model.add(Dropout(0.5))model.add(Dense(num_classes, activation='softmax'))model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])model.summary()

训练模型

history = model.fit(X,train_y, validation_data=(test_X, test_y),epochs=20,batch_size=100,verbose=True)score = model.evaluate(test_X, test_y, verbose=0)print("Large CNN Error: %.2f%%" %(100-score[1]*100))

 

转载地址:http://kshv.baihongyu.com/

你可能感兴趣的文章
MySQL原理简介—2.InnoDB架构原理和执行流程
查看>>
MySQL原理简介—3.生产环境的部署压测
查看>>
MySQL原理简介—6.简单的生产优化案例
查看>>
MySQL原理简介—7.redo日志的底层原理
查看>>
MySQL原理简介—8.MySQL并发事务处理
查看>>
MySQL原理简介—9.MySQL索引原理
查看>>
MySQL参数调优详解
查看>>
mysql参考触发条件_MySQL 5.0-触发器(参考)_mysql
查看>>
MySQL及navicat for mysql中文乱码
查看>>
MySqL双机热备份(二)--MysqL主-主复制实现
查看>>
MySQL各个版本区别及问题总结
查看>>
MySql各种查询
查看>>
mysql同主机下 复制一个数据库所有文件到另一个数据库
查看>>
mysql启动以后会自动关闭_驾照虽然是C1,一直是开自动挡的车,会不会以后就不会开手动了?...
查看>>
mysql启动和关闭外键约束的方法(FOREIGN_KEY_CHECKS)
查看>>
Mysql启动失败解决过程
查看>>
MySQL启动失败:Can't start server: Bind on TCP/IP port
查看>>
mysql启动报错
查看>>
mysql启动报错The server quit without updating PID file几种解决办法
查看>>
mysql命令
查看>>