-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLogo_classifier_train.py
More file actions
91 lines (77 loc) · 2.64 KB
/
Logo_classifier_train.py
File metadata and controls
91 lines (77 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV3Small
from tensorflow.keras.applications.mobilenet_v3 import preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Conv2D, MaxPooling2D, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
# Root directory of the dataset; flow_from_directory reads one sub-directory
# per class from here.
DATA_DIR = 'Logos'

# Square input resolution fed to the network, and images per mini-batch.
IMG_SIZE, BATCH_SIZE = 224, 16

# Epoch budgets: phase 1 (frozen backbone) and phase 2 (fine-tuning).
EPOCHS, FINE_TUNE_EPOCHS = 10, 5
# Fraction of each class sub-directory held out for validation. Both data
# generators below must use the same value so that subset='training' and
# subset='validation' partition the files identically (Keras splits each
# class's sorted file list by this fraction).
VALIDATION_SPLIT = 0.2

# Training images: backbone preprocessing plus random augmentation.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=VALIDATION_SPLIT,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Validation images: same preprocessing and split, but NO augmentation.
# Validation data is drawn from a separate, non-augmenting generator so that
# val_loss / val_accuracy (which drive EarlyStopping and ModelCheckpoint)
# measure the model on the actual held-out images. They are then deterministic
# across epochs rather than computed on random distortions.
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=VALIDATION_SPLIT,
)

train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='training',
    shuffle=True,
)
val_generator = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='validation',
    shuffle=False,  # fixed order keeps validation evaluation reproducible
)

# With class_mode='binary', labels 0/1 come from the alphanumeric order of the
# class sub-directory names. Print the mapping so the sigmoid output (the
# probability of label 1) can be interpreted.
print("Class indices:", train_generator.class_indices)
# ImageNet-pretrained MobileNetV3Small backbone without its classification
# top. It is frozen so that phase 1 trains only the new head.
base_model = MobileNetV3Small(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights='imagenet',
)
base_model.trainable = False

# Custom head on the backbone's output feature map: two
# conv -> max-pool -> dropout stages (64 then 128 filters), followed by global
# pooling, a dense layer and a single sigmoid unit for the binary decision.
head = base_model.output
for n_filters in (64, 128):
    head = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(head)
    head = MaxPooling2D(pool_size=(2, 2))(head)
    head = Dropout(0.3)(head)
head = GlobalAveragePooling2D()(head)
head = Dense(128, activation='relu')(head)
head = Dropout(0.4)(head)
authenticity_predictions = Dense(1, activation='sigmoid')(head)

model = Model(inputs=base_model.input, outputs=authenticity_predictions)
# Phase 1 trains only the layers that are trainable at this point; the
# backbone has been set non-trainable before this block runs.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Stop after 3 epochs without val_loss improvement and roll back to the best
# weights seen.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
    verbose=1,
)
# Persist the model whenever val_accuracy reaches a new best.
# NOTE(review): this monitors val_accuracy while early_stop monitors val_loss,
# so the saved file and the weights left in memory may come from different
# epochs. Confirm this is intended.
checkpoint = ModelCheckpoint(
    'authenticity_classifier_mobilenet.keras',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)

print("\n=== Training Phase 1: Frozen Base Model ===")
model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    callbacks=[early_stop, checkpoint],
    verbose=1,
)
# Phase 2: unfreeze the last FINE_TUNE_LAYERS layers of the backbone and train
# them together with the head at a small learning rate.
print("\n=== Training Phase 2: Fine-Tuning Top Layers ===")

# Number of trailing backbone layers to unfreeze.
FINE_TUNE_LAYERS = 30
# Learning rate for fine-tuning. It is far below Adam's default of 1e-3 so the
# pretrained weights are nudged rather than overwritten.
FINE_TUNE_LR = 1e-5

base_model.trainable = True
first_unfrozen = len(base_model.layers) - FINE_TUNE_LAYERS
for index, layer in enumerate(base_model.layers):
    # BatchNormalization layers stay frozen even inside the unfrozen tail.
    # The backbone is not called with training=False here, so a trainable BN
    # layer would run in training mode during fit(). Its moving statistics
    # would then drift on small batches of new data; a non-trainable BN layer
    # runs in inference mode.
    is_batch_norm = isinstance(layer, tf.keras.layers.BatchNormalization)
    layer.trainable = index >= first_unfrozen and not is_batch_norm

# Recompile so the new trainable flags and the lower learning rate take effect.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=FINE_TUNE_LR),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=FINE_TUNE_EPOCHS,
    callbacks=[early_stop, checkpoint],
    verbose=1,
)