-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathminiVGGNet_Flower17.py
More file actions
57 lines (43 loc) · 1.95 KB
/
miniVGGNet_Flower17.py
File metadata and controls
57 lines (43 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
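"""
Train MiniVGGNet on the Oxford Flowers-17 images, print a classification
report for the held-out split, and plot the training/validation curves.

Created on Tue Jan 8 16:28:59 2019
@author: DELL
"""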
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from Preprocessing.SimpleProcessor import SimplePreprocessor
from Preprocessing.AspectAwareProcessor import AspectAwareProcessor
from Preprocessing.ImageToArrayProcessor import ImageToArrayProcessor
from Dataset.SimpleDatasetLoader import SimpleDatasetLoader
from keras.optimizers import SGD
from NeurualNetwork.ConvolutionNeuralNetwork.miniVGGNet import MiniVGGNet
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
# --- Load and preprocess the dataset ---------------------------------------
imagePaths = list(paths.list_images("Dataset/oxfordflower17/jpg"))
print("[INFO]Load images")
# Presumably resizes each image to 64x64 (matching MiniVGGNet.build(64, 64, ...)
# below) and converts it to an array -- TODO confirm against the processors.
aap = AspectAwareProcessor(64, 64)
imgToArr = ImageToArrayProcessor()
loader = SimpleDatasetLoader(processors=[aap, imgToArr])
data, label = loader.load(imagePaths, verbose=500)
# Scale pixel intensities into [0, 1] (assumes 8-bit 0-255 input pixels).
data = data.astype('float')/255.

# --- Train/test split and label encoding -----------------------------------
X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=0.25, random_state=42)
# Fit the one-hot encoder on the training labels only and reuse that mapping
# for the test labels. Refitting on y_test could yield a different
# class->column order, or fewer columns, whenever the test split lacks a class.
lb = LabelBinarizer()
y_train = lb.fit_transform(y_train)
y_test = lb.transform(y_test)

# --- Build, compile and train the model -------------------------------------
numOfEpoch = 100
# Arguments presumably (width, height, depth, classes): 64x64x3 inputs, 17 classes.
model = MiniVGGNet.build(64, 64, 3, 17)
# SGD with learning rate 0.01, momentum 0.9 and Nesterov momentum. The learning
# rate stays positional because its keyword differs between Keras versions
# ('lr' in older releases, 'learning_rate' in newer ones).
model.compile(SGD(0.01, momentum=0.9, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy'])
H = model.fit(X_train, y_train, validation_data=(X_test, y_test), batch_size=32, epochs=numOfEpoch)

# --- Evaluate ----------------------------------------------------------------
prediction = model.predict(X_test, batch_size=32)
# Column i of both the one-hot targets and the predictions corresponds to
# lb.classes_[i], so label the report rows with those class names. Passing
# `labels` keeps the row count equal to len(target_names) even if some class
# never appears in y_true or y_pred.
print(classification_report(y_test.argmax(1), prediction.argmax(1),
                            labels=np.arange(len(lb.classes_)),
                            target_names=[str(c) for c in lb.classes_]))

# --- Plot learning curves ------------------------------------------------------
# Keras < 2.3 stores accuracy as 'acc'/'val_acc'; later versions use
# 'accuracy'/'val_accuracy'. Use whichever key this history contains.
accKey = 'accuracy' if 'accuracy' in H.history else 'acc'
# Derive the x-axis from the recorded history so its length always matches
# the plotted series.
epochs = np.arange(0, len(H.history['loss']))
fig = plt.figure()
plt.plot(epochs, H.history['loss'], label='training loss')
plt.plot(epochs, H.history['val_loss'], label='validation loss')
plt.plot(epochs, H.history[accKey], label='accuracy')
plt.plot(epochs, H.history['val_' + accKey], label='validation accuracy')
plt.title('Accuracy and Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss|Accuracy')
plt.legend()
# Without an explicit show(), a non-interactive run builds the figure but
# never displays it.
plt.show()