博客
关于我
Keras自定义网络进行十分类图像识别
阅读量:262 次
发布时间:2019-03-01

本文共 4177 字,大约阅读时间需要 13 分钟。

import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from keras.layers import Activation, Dense, Dropout, Flatten, Input
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.models import Model, Sequential
from keras.optimizers import SGD, Adam, RMSprop
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils
from sklearn.model_selection import train_test_split

图片预处理

def read_and_process_image(data_dir, width=32, height=32, channels=3, preprocess=False):
    """Load a folder of class-labelled images into one array plus labels.

    ``data_dir`` is expected to contain one sub-directory per class
    (``data_dir/<class>/<image file>``). The label of every image is taken
    from the name of its class sub-directory, which must contain one of the
    ten CIFAR-10 class names. Samples are shuffled once with ``random``
    before being loaded, so seed ``random`` for a reproducible order.

    Args:
        data_dir: Directory holding the class sub-directories. A trailing
            path separator is optional. Entries that are not directories
            are ignored.
        width: Target image width passed to ``image.load_img``.
        height: Target image height passed to ``image.load_img``.
        channels: Size of the last axis of the returned array. Images are
            loaded with ``load_img``'s default colour mode, so this has to
            match what that produces (presumably 3 for RGB; confirm before
            using another value).
        preprocess: Accepted for backward compatibility. The preprocessing
            step was already disabled in the original code, so this flag
            currently has no effect.

    Returns:
        A tuple ``(X, labels)``. ``X`` is a ``float32`` array of shape
        ``(count, height, width, channels)`` and ``labels`` is a list of
        ``count`` integer class indices in the same order.

    Raises:
        ValueError: If a class sub-directory that contains files does not
            match any known class name.
    """
    # Index in this tuple == integer label. The names are tested in this
    # order as substrings of the class directory name.
    class_names = (
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    )

    def label_for(class_dir_name):
        for index, name in enumerate(class_names):
            if name in class_dir_name:
                return index
        raise ValueError(
            "Cannot derive a class label from directory {!r}; expected its "
            "name to contain one of {}".format(class_dir_name, class_names)
        )

    def read_image(file_path):
        img = image.load_img(file_path, target_size=(height, width))
        return image.img_to_array(img)

    # Collect (path, label) pairs so that images and labels can never get
    # out of step, whatever order they end up in.
    samples = []
    for class_dir_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_dir_name)
        if not os.path.isdir(class_dir):
            continue
        file_names = sorted(
            name for name in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, name))
        )
        if not file_names:
            continue
        label = label_for(class_dir_name)
        samples.extend(
            (os.path.join(class_dir, name), label) for name in file_names
        )

    # Shuffle once, after every class has been collected.
    random.shuffle(samples)

    X = np.empty((len(samples), height, width, channels), dtype=np.float32)
    labels = []
    for i, (file_path, label) in enumerate(samples):
        X[i] = read_image(file_path)
        labels.append(label)

    print("Train shape: {}".format(X.shape))

    return X, labels

读取训练集,以及测试集

# Image geometry shared by the training-set and test-set loaders.
WIDTH = 32
HEIGHT = 32
CHANNELS = 3

# Load the training-set images.
X, y = read_and_process_image(
    'D:/Python Project/cifar-10/train/',
    width=WIDTH, height=HEIGHT, channels=CHANNELS,
)

# Load the test-set images.
test_X, test_y = read_and_process_image(
    'D:/Python Project/cifar-10/test/',
    width=WIDTH, height=HEIGHT, channels=CHANNELS,
)

# Class distribution of the training labels. The labels are passed as x=
# because newer seaborn releases treat the first positional argument as
# `data`. Each plot gets its own figure so the two do not overlay.
plt.figure()
sns.countplot(x=y)
plt.show()

# Class distribution of the test labels.
plt.figure()
sns.countplot(x=test_y)
plt.show()

one-hot编码

# One-hot encode the integer labels. The class count is given explicitly
# because to_categorical otherwise infers it from the largest label present,
# so a split that lacks the highest class would come out narrower than the
# 10-way softmax output of the model.
train_y = np_utils.to_categorical(y, num_classes=10)
test_y = np_utils.to_categorical(test_y, num_classes=10)

显示图片

def show_picture(X, idx, bgr=False):
    """Display one image of a batch with matplotlib.

    Args:
        X: Batch of images indexed as ``X[idx]`` with the colour channels on
            the last axis and pixel values on a 0-255 scale.
        idx: Index of the image to show.
        bgr: Set to True when the channels are stored in BGR order; they are
            then reversed before display. The default is False because the
            images in this file are loaded through Keras ``load_img``, so
            reversing them would swap red and blue on screen.
    """
    plt.figure(figsize=(10, 5), frameon=True)
    img = X[idx]
    if bgr:
        img = img[:, :, ::-1]
    # imshow expects float pixel values in the range [0, 1].
    img = img / 255
    plt.imshow(img)
    plt.show()


# Preview the first three training images.
for idx in range(3):
    show_picture(X, idx)

定义模型

num_classes = 10

# Four convolutional stages (two 3x3 "same" convolutions followed by a 2x2
# max-pool) with 32, 64, 128 and 256 filters, then a two-layer dense head
# with dropout and a softmax output over the classes.
model = Sequential([
    # Stage 1: the first layer also declares the input shape.
    Conv2D(32, 3, input_shape=(HEIGHT, WIDTH, CHANNELS), activation='relu', padding='same'),
    Conv2D(32, 3, activation='relu', padding='same'),
    MaxPooling2D(pool_size=2),
    # Stage 2
    Conv2D(64, 3, activation='relu', padding='same'),
    Conv2D(64, 3, activation='relu', padding='same'),
    MaxPooling2D(pool_size=2),
    # Stage 3
    Conv2D(128, 3, activation='relu', padding='same'),
    Conv2D(128, 3, activation='relu', padding='same'),
    MaxPooling2D(pool_size=2),
    # Stage 4
    Conv2D(256, 3, activation='relu', padding='same'),
    Conv2D(256, 3, activation='relu', padding='same'),
    MaxPooling2D(pool_size=2),
    # Classifier head
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

训练模型

# Train for 20 epochs in batches of 100. The test split is passed as the
# validation data, so the per-epoch validation figures are computed on the
# same images that are scored below.
history = model.fit(
    X,
    train_y,
    batch_size=100,
    epochs=20,
    verbose=True,
    validation_data=(test_X, test_y),
)

# score[1] is the accuracy, the only metric the model was compiled with;
# it is reported as an error percentage.
score = model.evaluate(test_X, test_y, verbose=0)
print("Large CNN Error: %.2f%%" % (100 - score[1] * 100))

 

转载地址:http://kshv.baihongyu.com/

你可能感兴趣的文章
Mysql DBA 高级运维学习之路-DQL语句之select知识讲解
查看>>
mysql deadlock found when trying to get lock暴力解决
查看>>
MuseTalk如何生成高质量视频(使用技巧)
查看>>
mutiplemap 总结
查看>>
MySQL DELETE 表别名问题
查看>>
MySQL Error Handling in Stored Procedures---转载
查看>>
MVC 区域功能
查看>>
MySQL FEDERATED 提示
查看>>
mysql generic安装_MySQL 5.6 Generic Binary安装与配置_MySQL
查看>>
Mysql group by
查看>>
MySQL I 有福啦,窗口函数大大提高了取数的效率!
查看>>
mysql id自动增长 初始值 Mysql重置auto_increment初始值
查看>>
MySQL in 太多过慢的 3 种解决方案
查看>>
MySQL InnoDB 三大文件日志,看完秒懂
查看>>
Mysql InnoDB 数据更新导致锁表
查看>>
Mysql Innodb 锁机制
查看>>
MySQL InnoDB中意向锁的作用及原理探
查看>>
MySQL InnoDB事务隔离级别与锁机制深入解析
查看>>
Mysql InnoDB存储引擎 —— 数据页
查看>>
Mysql InnoDB存储引擎中的checkpoint技术
查看>>