ENEL645_Project

Deep learning

Through exploratory data analysis of the development dataset, we gained an understanding of the characteristics of the various disease classes. As shown in Figure 1, brown spots are indicative of CBB (class 0). CBSD (class 1) causes a characteristic yellow or necrotic vein banding, which may enlarge and coalesce to form comparatively large yellow patches (Akinfenwa, 2019). CGM (class 2) and CMD (class 3) show similar mottled symptoms.

As shown in Figure 2, the dataset is fairly noisy. Some images labelled CBSD (class 1) do not show the leaves at all; instead, they show cassava's tuberous roots. Some images are mislabelled, and some do not show a close-up of the leaves, which makes feature extraction difficult.

Review of the class distribution showed that the dataset is highly imbalanced. To account for this imbalance during validation, the model was validated using StratifiedKFold, a variant of k-fold cross-validation that is preferred for skewed datasets because it preserves the class proportions in each training and validation split (Bartleby Research, n.d.). The development set was split into 5 folds, with one fold used for validation and the remaining four for training. The best model, based on validation accuracy, was submitted to Kaggle for evaluation against the test dataset.
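Stratified splitting can be sketched in a few lines of plain Python. This is a simplified illustration of the idea, not scikit-learn's implementation; in practice `sklearn.model_selection.StratifiedKFold(n_splits=5, shuffle=True, random_state=...)` provides it through `.split(X, y)`. The function name `stratified_kfold` and the seed are hypothetical.

```python
import random
from collections import defaultdict


def stratified_kfold(labels, n_splits=5, seed=0):
    """Yield (train_idx, val_idx) index lists whose class proportions match `labels`.

    Simplified stratification: each class's indices are shuffled and dealt
    round-robin across the folds, so every fold receives about 1/n_splits
    of every class.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    folds = [[] for _ in range(n_splits)]
    offset = 0  # carried across classes so leftover samples don't all land in fold 0
    for indices in by_class.values():
        rng.shuffle(indices)
        for i, idx in enumerate(indices):
            folds[(offset + i) % n_splits].append(idx)
        offset += len(indices)

    for k in range(n_splits):
        val_idx = sorted(folds[k])
        # Training set = every fold except the k-th.
        train_idx = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        yield train_idx, val_idx
```

For example, with toy labels in which one class makes up 60% of the samples and four classes 10% each, every validation fold keeps that 60/10/10/10/10 split.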

We built our classifier on several pre-trained models, namely tf_efficientnet_b3_ns, resnext50_32x4d, vit_small_patch16_224, vit_base_patch16_224, and deit_base_patch16_224. To determine a suitable number of training epochs, we plotted the training and validation loss against the number of epochs for tf_efficientnet_b3_ns. As shown in Figure 3, the best validation accuracy was obtained after 10 epochs; beyond that point, the model clearly overfit the training data. Similar patterns were observed for the ViT and ResNeXt models; however, the DeiT model continued to improve. As shown in Figure 4, deit_base_patch16_224 starts at a much lower accuracy and higher loss (~60% accuracy and a loss of ~1), but it keeps improving even after 90 epochs of training and shows no signs of overfitting. Ideally, we would have trained this model further, but computing constraints made this infeasible. Moreover, even after 90 epochs, its best accuracy and loss were 75.5% and 0.6779 respectively, still worse than the EfficientNet, ResNeXt and ViT models. Because training had already taken more than 9 hours and there was no guarantee that additional training would let it outperform the other models, we did not investigate this model further.

The performance of the chosen models is shown in Table 1. resnext50_32x4d performed best, with its best checkpoint found after 6 epochs: it achieved a validation accuracy of 87.08%, with training and validation losses of 0.3843 and 0.3907 respectively. The runner-up, tf_efficientnet_b3_ns, reached its best performance after 10 epochs, with a validation accuracy of 87.01% and training and validation losses of 0.3241 and 0.3891 respectively. The vit_base_patch16_224 model outperformed the vit_small_patch16_224 model, with a validation accuracy of 86.71% against 84.72%. Both ViT models showed signs of overfitting after around 10 epochs.

We also evaluated different optimizers and learning rate schedulers using tf_efficientnet_b3_ns. As shown in Table 2, Adam performed best compared with AdamP, AdamW and Ranger: models trained with Adam, AdamW, AdamP and Ranger achieved validation accuracies of 87.01%, 86.98%, 86.96% and 86.71% respectively.
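To make the Adam/AdamW difference concrete, here is a scalar sketch of one update step. AdamW differs from Adam only in how weight decay is applied: Adam adds it to the gradient as an L2 penalty, so it is rescaled by the adaptive denominator, whereas AdamW shrinks the weight directly. The function name `adam_step` is hypothetical, the defaults mirror PyTorch's Adam defaults, and AdamP and Ranger are not sketched.

```python
import math


def adam_step(theta, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              weight_decay=0.0, decoupled=False):
    """One scalar Adam (decoupled=False) or AdamW (decoupled=True) update.

    Returns (new_theta, new_m, new_v); `t` is the 1-based step count.
    """
    beta1, beta2 = betas
    if decoupled:
        # AdamW: shrink the weight directly, independent of the gradient moments.
        theta = theta * (1 - lr * weight_decay)
    else:
        # Adam with an L2 penalty: the decay term is folded into the gradient.
        grad = grad + weight_decay * theta
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v
```

On the first step the bias-corrected update is roughly `lr * sign(grad)`, so with `theta=1.0`, `grad=0.5`, `lr=0.1` and `weight_decay=0.1`, Adam's L2 term is absorbed by the normalization (theta ends near 0.9) while AdamW still shrinks the weight first (theta ends near 0.89).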

For the learning rate scheduler, we compared CosineAnnealingWarmRestarts, CosineAnnealingLR and ReduceLROnPlateau. According to Leslie Smith, letting the learning rate vary cyclically between reasonable bounds, instead of decreasing it monotonically, can improve model accuracy in fewer steps (Smith, 2017; Le, 2018). CosineAnnealingLR and CosineAnnealingWarmRestarts set the learning rate of each parameter group using a cosine annealing schedule. A "warm restart" resets the learning rate to its initial value while keeping the weights learned so far as the starting point for the next cycle (Loshchilov & Hutter, 2016). As shown in Table 3, CosineAnnealingWarmRestarts performed best, with a validation accuracy of 87.01%, compared with 86.96% for CosineAnnealingLR and 86.80% for ReduceLROnPlateau.
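The cosine schedule with warm restarts sets the learning rate to eta_min + (eta_max - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2, where T_cur counts epochs since the last restart and T_i is the length of the current cycle. The sketch below computes that value for a given epoch. Its parameter names follow PyTorch's `CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0)`, but the function itself is a hypothetical stand-alone illustration, and the values in the example are assumptions rather than the project's settings.

```python
import math


def cosine_warm_restarts_lr(epoch, base_lr, t_0, t_mult=1, eta_min=0.0):
    """Learning rate at a (possibly fractional) epoch under cosine annealing with warm restarts."""
    if t_mult < 1:
        raise ValueError("t_mult must be >= 1")
    t_i, t_cur = t_0, epoch
    # Skip completed cycles; each new cycle is t_mult times longer than the last.
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    # Cosine decay from base_lr (at t_cur = 0) toward eta_min (as t_cur -> t_i).
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2
```

With `base_lr=1e-4` and `t_0=10`, the rate falls to 5e-5 at epoch 5, approaches 0 just before epoch 10, and jumps back to 1e-4 at epoch 10. CosineAnnealingLR corresponds to a single cycle with no restart.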

Next, we tuned the loss function to address label noise. Symmetric Cross Entropy and Focal Loss did not improve the validation accuracy of our model: their best validation accuracies were 86.89% and 86.92% respectively, compared with 87.08% for Cross Entropy Loss. The best models using Symmetric Cross Entropy and Focal Loss were found after 6 and 10 epochs respectively.
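Both losses can be written per sample in plain Python. Focal Loss scales cross entropy by (1 - p_t)^gamma so that well-classified examples contribute less; Symmetric Cross Entropy adds a "reverse" cross entropy term, in which log(0) on the one-hot label is clipped to a constant. This is a minimal sketch; the function names are hypothetical, and the hyperparameters (`gamma=2`, `alpha=0.1`, `beta=1.0`, `log_zero=-4`) are illustrative assumptions, not the values used in the project.

```python
import math


def softmax(logits):
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss(logits, target, gamma=2.0):
    """-(1 - p_t)^gamma * log(p_t); gamma = 0 recovers ordinary cross entropy."""
    p_t = softmax(logits)[target]
    return -((1 - p_t) ** gamma) * math.log(p_t)


def symmetric_cross_entropy(logits, target, alpha=0.1, beta=1.0, log_zero=-4.0):
    """alpha * CE + beta * reverse CE, with log(0) on the one-hot label clipped to log_zero."""
    probs = softmax(logits)
    ce = -math.log(probs[target])
    # Reverse CE: -sum_k p_k * log(y_k); log(1) = 0 for the true class, log(0) -> log_zero otherwise.
    rce = -sum(p * log_zero for k, p in enumerate(probs) if k != target)
    return alpha * ce + beta * rce
```

Because the reverse term is bounded (at most `-log_zero`), a confidently wrong label cannot drive the loss arbitrarily high, which is the property that makes it attractive for noisy labels.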

On the other hand, Focal Cosine Loss, Taylor Cross Entropy, Label Smoothing and the combination of Taylor Cross Entropy with Label Smoothing all improved the validation accuracy. These models also took longer to train, with the best checkpoints found at around 20 epochs.

Focal Cosine Loss reached its best performance after 20 epochs, with a validation accuracy of 87.66%. It also produced lower training and validation losses of 0.1265 and 0.1521 respectively, compared with 0.3241 and 0.3891 for Cross Entropy Loss, although values from different loss functions are not directly comparable.
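The report does not define Focal Cosine Loss. The sketch below assumes one formulation seen in public notebooks for the Kaggle cassava competition: a cosine embedding loss between the logits and the one-hot target, plus a small focal term (weighted by `xent`) computed on the L2-normalized logits. Treat the formula, the function name and the defaults `alpha=1`, `gamma=2`, `xent=0.1` as assumptions.

```python
import math


def softmax(logits):
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_cosine_loss(logits, target, alpha=1.0, gamma=2.0, xent=0.1):
    """Cosine embedding loss to the one-hot target plus a down-weighted focal term (assumed form)."""
    norm = max(math.sqrt(sum(z * z for z in logits)), 1e-12)
    # 1 - cos(logits, one_hot): the dot product with a one-hot vector is just logits[target].
    cosine_loss = 1.0 - logits[target] / norm
    # Focal loss computed on the L2-normalized logits.
    normalized = [z / norm for z in logits]
    ce = -math.log(softmax(normalized)[target])
    p_t = math.exp(-ce)
    focal = alpha * (1.0 - p_t) ** gamma * ce
    return cosine_loss + xent * focal
```

Under this form the cosine term reaches 0 when the logits point exactly at the target class, and the focal term is scaled by 0.1, so loss values well below those of plain cross entropy are expected.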

The best model using Taylor Cross Entropy was found after 17 epochs and increased the validation accuracy of our ResNeXt model to 87.69%. However, its training and validation losses were also higher, at 0.5114 and 0.5653 respectively.
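The report does not spell out its Taylor Cross Entropy. A sketch, assuming the "Taylor softmax" formulation used in public notebooks for this competition: exp(z) inside the softmax is replaced by its order-n Taylor polynomial 1 + z + z^2/2! + ... + z^n/n!, followed by an ordinary (optionally label-smoothed) cross entropy. A different loss with the same name instead truncates the Taylor series of -log p. The function names, `n=2`, and the smoothing convention here are assumptions.

```python
import math


def taylor_softmax(logits, n=2):
    """Softmax with exp(z) replaced by its order-n Taylor polynomial."""
    # For even n the polynomial is strictly positive, so the normalization is well defined.
    values = []
    for z in logits:
        total, term = 1.0, 1.0
        for i in range(1, n + 1):
            term *= z / i          # builds z^i / i! incrementally
            total += term
        values.append(total)
    s = sum(values)
    return [v / s for v in values]


def taylor_cross_entropy(logits, target, n=2, smoothing=0.0):
    """Cross entropy on Taylor-softmax outputs against (optionally smoothed) targets."""
    probs = taylor_softmax(logits, n)
    k = len(logits)
    # 1 - smoothing on the true class, smoothing / (k - 1) on each other class.
    targets = [1 - smoothing if c == target else smoothing / (k - 1) for c in range(k)]
    return -sum(t * math.log(p) for t, p in zip(targets, probs))
```

Setting `smoothing` to 0.2 or 0.5 gives the Taylor Cross Entropy plus Label Smoothing variants evaluated in the experiments.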

We also tested Label Smoothing with smoothing factors of 0, 0.3, 0.5 and 0.7, as seen in Table 5. A smoothing factor of 0.5 gave the highest validation accuracy, 88.20%. This model had high training and validation losses of 1.4217 and 1.4353 respectively, which is expected: label smoothing lowers the confidence placed on the labels by softening the target values, so the loss cannot approach zero.
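The high loss values follow directly from the smoothed targets. Cross entropy against a target distribution q is minimized when the prediction equals q, so its floor is the entropy of q. A minimal sketch, assuming five classes (the four diseases above plus a healthy class) and the convention of giving 1 - smoothing to the true class and spreading the rest evenly over the other classes; PyTorch's built-in `nn.CrossEntropyLoss(label_smoothing=...)` instead spreads smoothing / K over all K classes, which gives a slightly lower floor. The function names are hypothetical.

```python
import math


def smoothed_targets(target, num_classes, smoothing):
    """1 - smoothing on the true class; the remainder spread evenly over the other classes."""
    off = smoothing / (num_classes - 1)
    return [1 - smoothing if c == target else off for c in range(num_classes)]


def smoothed_cross_entropy(probs, target, smoothing):
    """Cross entropy of predicted probabilities against the smoothed target distribution."""
    q = smoothed_targets(target, len(probs), smoothing)
    return -sum(t * math.log(p) for t, p in zip(q, probs))


def min_achievable_loss(num_classes, smoothing):
    """Entropy of the smoothed targets: the lowest loss any prediction can reach."""
    q = smoothed_targets(0, num_classes, smoothing)
    return -sum(p * math.log(p) for p in q if p > 0)
```

With five classes and a smoothing factor of 0.5, the targets are 0.5 for the true class and 0.125 for each other class, so the floor is ln 4 (about 1.386), just below the observed losses of about 1.42 to 1.44. A confident prediction such as [0.96, 0.01, 0.01, 0.01, 0.01] actually scores worse (about 2.32) than predicting the smoothed distribution itself.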

Finally, we evaluated combining Taylor Cross Entropy loss with Label Smoothing. With smoothing factors of 0.2 and 0.5, the models achieved validation accuracies of 87.29% and 87.69% respectively. The results for all loss functions can be seen in Table 4.

From our validation tests, the best hyperparameters for our dataset were the Adam optimizer, the CosineAnnealingWarmRestarts learning rate scheduler, and Label Smoothing with a smoothing factor of 0.5 as the criterion, using the resnext50_32x4d model. This model achieved a test accuracy of 87.67% upon submission.

In addition to hyperparameter tuning, we applied ensemble learning and test-time augmentation (TTA). tf_efficientnet_b3_ns, trained with the best hyperparameters found above, achieved a validation accuracy of 87.20% without TTA and 87.73% with TTA. Combining the best EfficientNet model (with TTA) and the best ResNeXt model achieved a test accuracy of 88.24% upon submission. To further improve our scores, we trained five ResNeXt and five EfficientNet models, one per validation fold. With TTA applied to both architectures, this ensemble achieved a test accuracy of 88.69%. Finally, we achieved a slightly higher test score of 89.03% by not applying TTA to the ResNeXt models. A summary of our final test submission scores can be found in Table 6.
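The averaging behind TTA and the ensemble can be sketched as follows. The report does not say how predictions were combined; this sketch assumes equal-weight averaging of softmax probabilities, first over augmented views of an image and then across models. Here `model` is any callable that returns class probabilities, and the transforms (for example a horizontal flip) and function names are hypothetical. Passing an empty transform list for a model disables TTA for it, as in the final ResNeXt configuration.

```python
def average_probs(prob_lists):
    """Element-wise mean of several probability vectors for the same image."""
    n = len(prob_lists)
    return [sum(ps) / n for ps in zip(*prob_lists)]


def predict_with_tta(model, image, transforms):
    """Average a model's class probabilities over the original image and its augmented views."""
    views = [image] + [t(image) for t in transforms]
    return average_probs([model(view) for view in views])


def ensemble_predict(models, image, transforms_per_model):
    """Average TTA probabilities across models and return the index of the top class."""
    per_model = [predict_with_tta(m, image, tfs)
                 for m, tfs in zip(models, transforms_per_model)]
    probs = average_probs(per_model)
    return max(range(len(probs)), key=probs.__getitem__)
```

Averaging probabilities rather than hard votes lets a confident model outweigh an uncertain one, which matters when only a handful of models are combined.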

Contributors

jchoi64, tongxu95, karenzhang7717, shiyuzhou96
