| import tensorflow as tf |
| import os |
| import matplotlib.pyplot as plt |
| from dotenv import load_dotenv |
| from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
|
# Load values from a .env file (python-dotenv) before the environment lookup
# below, so that TRAIN_DATASET can be supplied that way.
load_dotenv()

# Input pipeline settings: generator batch size and the size every image is
# resized to (also used for the model's input shape).
BATCH_SIZE = 32
IMG_SIZE = (224, 224)

# Training settings handed to model.compile() / model.fit().
EPOCHS = 8
OPTIMIZER = "adam"
LOSS_FUNC = "binary_crossentropy"

# Root directory of the training images; None when the variable is not set.
TRAIN_DATASET = os.environ.get("TRAIN_DATASET")
|
|
| |
def load_data():
    """Create the training and validation image generators.

    Images are read from the directory named by ``TRAIN_DATASET``, resized to
    ``IMG_SIZE`` and batched in groups of ``BATCH_SIZE`` with binary labels.
    80% of the files feed the training generator and 20% the validation
    generator.

    Returns:
        A ``(train_data, val_data)`` tuple of directory iterators.

    Raises:
        ValueError: If ``TRAIN_DATASET`` is unset or empty.
    """
    if not TRAIN_DATASET:
        # Fail early with an actionable message instead of letting a None
        # path surface as an obscure error inside Keras.
        raise ValueError(
            "TRAIN_DATASET environment variable is not set; "
            "point it at the root directory of the training images."
        )

    # Both generators must share the same split fraction so that the
    # "training" and "validation" subsets select complementary files.
    validation_split = 0.2
    rescale = 1. / 255

    # Random augmentation is applied to the training images only.
    train_datagen = ImageDataGenerator(
        validation_split=validation_split,
        rescale=rescale,
        horizontal_flip=True,
        zoom_range=0.2,
    )

    # Validation images are only rescaled: augmenting them would make the
    # validation metrics noisy and unrepresentative of real inputs.
    val_datagen = ImageDataGenerator(
        validation_split=validation_split,
        rescale=rescale,
    )

    train_data = train_datagen.flow_from_directory(
        directory=TRAIN_DATASET,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode="binary",
        subset="training",
        shuffle=True,
    )

    val_data = val_datagen.flow_from_directory(
        directory=TRAIN_DATASET,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode="binary",
        subset="validation",
        # Order does not affect evaluation; a fixed order keeps it repeatable.
        shuffle=False,
    )

    return train_data, val_data
|
|
| |
def build_model():
    """Build and compile the convolutional binary classifier.

    The network is three Conv2D + MaxPooling2D stages (32, 64 and 128
    filters) followed by a 512-unit dense layer and a single sigmoid output.
    It is compiled with ``OPTIMIZER`` and ``LOSS_FUNC`` and tracks accuracy.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    layers = tf.keras.layers

    model = tf.keras.Sequential([
        # Declare the input explicitly rather than via ``input_shape=`` on
        # the first layer, which is deprecated for Sequential models.
        tf.keras.Input(shape=(*IMG_SIZE, 3)),

        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),

        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),

        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),

        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        # Single sigmoid unit to match the binary labels and loss.
        layers.Dense(1, activation='sigmoid'),
    ])

    model.compile(
        optimizer=OPTIMIZER,
        loss=LOSS_FUNC,
        metrics=['accuracy'],
    )

    return model
|
|
| |
def _plot_history(history):
    """Show training/validation accuracy and loss curves side by side.

    Accuracy and loss are drawn on separate axes because they are on
    different scales; sharing one axis makes both curves hard to read.

    Args:
        history: The object returned by ``model.fit``; its ``history`` dict
            must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
    """
    metrics = history.history
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

    acc_ax.plot(metrics['accuracy'], label='Train Accuracy')
    acc_ax.plot(metrics['val_accuracy'], label='Validation Accuracy')
    acc_ax.set_title('Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    loss_ax.plot(metrics['loss'], label='Train Loss')
    loss_ax.plot(metrics['val_loss'], label='Validation Loss')
    loss_ax.set_title('Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    fig.tight_layout()
    plt.show()


def main():
    """Train the classifier, save it to disk and plot the training curves."""
    train_data, val_data = load_data()
    model = build_model()

    history = model.fit(
        train_data,
        epochs=EPOCHS,
        validation_data=val_data,
    )

    # Save before plotting so the model is on disk even if the plot window
    # fails to open or is interrupted.
    model.save("cat_dog_model.h5")

    _plot_history(history)
|
|
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|