Lemmy is a bad place to write comments, so I'll do it here instead. You can delete it entirely later.
So I have a few small comments on the model. You are using a classic network, almost as simple as it gets for classification:
# Define the model architecture
model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=(img_height, img_width, 1)),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(num_classes)  # remove the softmax activation here
])
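Since that last Dense layer now outputs raw logits instead of probabilities, the loss needs to know it. A minimal sketch of how I'd compile it (the sizes here are placeholder values I made up, plug in your own):

```python
from tensorflow import keras

img_height = img_width = 28  # placeholder values, use your own
num_classes = 10             # placeholder value, use your own

model = keras.Sequential([
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu",
                        input_shape=(img_height, img_width, 1)),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(num_classes),  # raw logits, no softmax
])

# from_logits=True tells the loss to apply softmax internally,
# which is more numerically stable than softmax layer + plain CE
model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```

If you ever need actual probabilities at inference time, you can just wrap the predictions in a softmax afterwards.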
It's really good as a baseline, but you will probably:
- take a long time to train
- get results only as good as your dataset allows on its own, but not more
One thing you can use is transfer learning. The classic network used for classification AFAIK is ResNet. If you take this network and train only the last layers, you might get better results and shorter training time, because you are leveraging the prior knowledge of the pre-trained network.
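Roughly, that could look like this in Keras (a sketch, not something I've run on your data; the function name and image size are mine, and note that ResNet expects 3-channel input, so your grayscale images would need to be repeated across channels or otherwise adapted):

```python
from tensorflow import keras

def build_transfer_model(num_classes, img_size=(224, 224), weights="imagenet"):
    # Pre-trained ResNet50 as a frozen feature extractor
    base = keras.applications.ResNet50(
        include_top=False, weights=weights,
        input_shape=img_size + (3,), pooling="avg")
    base.trainable = False  # freeze it: only the new head gets trained

    inputs = keras.Input(shape=img_size + (3,))
    x = keras.applications.resnet50.preprocess_input(inputs)
    x = base(x, training=False)                 # keep BatchNorm in inference mode
    outputs = keras.layers.Dense(num_classes)(x)  # logits, no softmax
    return keras.Model(inputs, outputs)
```

Once the head converges you can optionally unfreeze the last few blocks of the base and fine-tune with a much lower learning rate.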
If you need a paper, I would really consider looking at contrastive learning methods :).
Looking at your method, if I were reviewing it my question would be more about how you get the labels for the data and how good your labelling is. Is it perfect? Is it mostly good? Is it 50% good? If it's not perfect and you have quite a few wrong labels, then you have one more incentive to use contrastive learning, since it is more robust to label noise (I can send you the paper on this on Monday if you want).
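Just to illustrate the kind of objective contrastive methods use, here's a toy NumPy sketch of the NT-Xent loss from SimCLR (my own illustration, nothing from your code): it pulls together embeddings of two augmented views of the same image and pushes apart everything else, without ever looking at the labels, which is where the robustness to label noise comes from.

```python
import numpy as np

def nt_xent(z1, z2, tau=0.5):
    """NT-Xent (normalized temperature-scaled cross entropy) loss.

    z1, z2: L2-normalized embeddings of two augmented views, shape (N, d),
    where z1[i] and z2[i] come from the same original image.
    """
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)       # (2N, d)
    sim = z @ z.T / tau                        # cosine similarities (inputs are normalized)
    np.fill_diagonal(sim, -np.inf)             # exclude self-similarity

    # for row i the positive is i+n, and vice versa
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    logsumexp = np.log(np.exp(sim).sum(axis=1))
    loss = -(sim[np.arange(2 * n), pos] - logsumexp)
    return loss.mean()
```

In practice you'd use the framework's tensor ops so it's differentiable, but the structure is the same: no labels anywhere in the loss.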