Deep Learning

101_Overfitting

elif 2024. 3. 10. 22:42

In the previous post (100_DNN Example using PyTorch), I explored solving the MNIST classification problem with a DNN that added three hidden layers to an ANN. During training, the training loss decreased steadily, but the validation loss decreased to a certain point and then began to increase, which I noted was likely a sign of overfitting. In this post, I will discuss what overfitting is.

 

Overfitting is a common problem in machine learning, where the model performs well on training data but fails to generalize effectively to unseen data.

In machine learning, our goal is to train a model using the training dataset and then use the trained model to make predictions on new data. We compile data not included in the training data into a test dataset and measure the model's performance on this test data to estimate how it will perform on new data. Since the test dataset cannot be referenced by the model during the training process, it's crucial to ensure that the model can make accurate predictions on the test data using only what it learned from the training dataset.
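Below is a minimal sketch of holding out a test set with PyTorch. The dummy tensors and the 80/20 split ratio are assumptions for illustration, not the data or code from the original post.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Dummy data standing in for a real dataset such as MNIST (784 features, 10 classes).
X = torch.randn(1000, 784)
y = torch.randint(0, 10, (1000,))
dataset = TensorDataset(X, y)

# Keep 80% for training; the remaining 20% is never shown to the model
# during training and is only used to estimate performance on new data.
n_train = int(0.8 * len(dataset))
train_set, test_set = random_split(dataset, [n_train, len(dataset) - n_train])
print(len(train_set), len(test_set))  # 800 200
```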

 

Fundamentally, machine learning models are trained to minimize the loss function. Therefore, as training progresses, the model's decision boundary becomes fitted to the training dataset. Let's look at the following figure.

 

Figure source: Raschka, Sebastian, et al. Machine Learning with PyTorch and Scikit-Learn: Develop Machine Learning and Deep Learning Models with Python. United Kingdom, Packt Publishing, 2022.

 

The figure shows underfitting, where the model fails to adequately capture the patterns in the training data; learning performance is poor, and predictions on unseen data are practically impossible. On the other hand, if the model becomes too closely fitted to the training dataset, that is, if overfitting occurs, the model ends up learning information specific to the training dataset rather than the general structure or pattern of the data.

 

If the distributions of the training data and test data were identical, fitting the model more closely to the training dataset would improve its performance and increase accuracy. However, in almost all cases, it's impossible to collect enough training data to fully represent the entirety of the data. Therefore, a model that is slightly less fitted to the training dataset may show higher accuracy on test or validation data.

 

Therefore, the situation where a model becomes overly fitted to the training dataset, showing lower prediction errors on the training data but higher prediction errors on the test data, is referred to as the overfitting problem. Overfitting can arise from various factors, such as having insufficient training data or using a model that is too complex.

 

Figure source: Ghojogh, B., & Crowley, M. (2019). The theory behind overfitting, cross validation, regularization, bagging, and boosting: Tutorial. arXiv preprint arXiv:1905.12787.

 

As illustrated in the figure, in the underfitting phase, both training loss and test loss decrease together. However, in the overfitting phase, even though the training loss continues to decrease, the test loss begins to increase. Therefore, the objective in training predictive models is to eliminate underfitting while stopping the training just before overfitting occurs. One way to identify overfitting is to examine the loss curves and check whether the test loss starts to increase while the training loss keeps decreasing. Another method is to use a validation dataset to determine whether overfitting has occurred.
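As a rough sketch of reading the loss curves, the snippet below takes per-epoch loss histories and finds the epoch where the held-out loss bottoms out while the training loss keeps falling. The loss values are made up purely for illustration.

```python
# Per-epoch loss histories (illustrative numbers, not measured results).
train_loss = [0.92, 0.61, 0.45, 0.36, 0.30, 0.26, 0.23, 0.21]
held_out_loss = [0.95, 0.66, 0.52, 0.47, 0.46, 0.48, 0.52, 0.57]

# The epoch with the lowest held-out loss; beyond this point the held-out
# loss rises even though the training loss keeps decreasing, i.e. overfitting.
best_epoch = min(range(len(held_out_loss)), key=held_out_loss.__getitem__)
print(f"held-out loss bottoms out at epoch {best_epoch}")
```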

 

The validation dataset is a separate dataset referenced during the training process to determine whether overfitting has occurred. It is typically constructed by setting aside some of the training data. When all three types of dataset exist (training, validation, and test), training generally focuses on minimizing the prediction error on the training data, while the prediction error on the validation data is monitored. If the prediction error on the validation data increases, training is terminated; otherwise, it continues. The final performance of the trained model is then evaluated by measuring the prediction error on the test data.
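The following is a minimal early-stopping sketch built on this idea: training stops once the validation loss stops improving. The model architecture, the dummy data, and the patience value are assumptions for illustration, not the post's actual code.

```python
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Dummy training and validation sets standing in for a real split.
train_loader = DataLoader(TensorDataset(torch.randn(800, 784),
                                        torch.randint(0, 10, (800,))), batch_size=64)
val_loader = DataLoader(TensorDataset(torch.randn(200, 784),
                                      torch.randint(0, 10, (200,))), batch_size=64)

best_val, patience, bad_epochs = float("inf"), 3, 0
best_state = copy.deepcopy(model.state_dict())

for epoch in range(100):
    model.train()
    for xb, yb in train_loader:               # minimize the training loss
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():                     # monitor the validation loss
        val_loss = sum(criterion(model(xb), yb).item()
                       for xb, yb in val_loader) / len(val_loader)

    if val_loss < best_val:                   # still improving: keep training
        best_val, bad_epochs = val_loss, 0
        best_state = copy.deepcopy(model.state_dict())
    else:                                     # validation loss got worse
        bad_epochs += 1
        if bad_epochs >= patience:            # stop before overfitting worsens
            break

model.load_state_dict(best_state)             # restore the best checkpoint
```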

 

Through this process, we can evaluate how accurately the current model predicts data it has not referenced during training, and use the evaluation results as a criterion for terminating the training, thereby indirectly preventing overfitting.

The biggest difference between test data and validation data is whether they are referenced during the training process. The validation dataset can be referenced during training, but the test dataset cannot. Although both the validation dataset and test dataset are used to evaluate the performance of the model, the validation dataset is exclusively used to decide when to stop training, whereas the test dataset is used to assess the model's final accuracy.

 

Typically, when overfitting occurs, it can be mitigated by simplifying the model, either by reducing the size of the model's layers or the number of neurons, by using Dropout, or by applying regularization, so that the model is not excessively complex relative to the training data.
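As a rough preview, a model with these remedies applied might look like the sketch below. The layer sizes, dropout probabilities, and the weight_decay value are assumptions, not settings from this series.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(784, 128),   # smaller hidden layers reduce model capacity
    nn.ReLU(),
    nn.Dropout(p=0.5),     # Dropout randomly zeroes activations during training
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(64, 10),
)

# weight_decay applies L2 regularization, penalizing overly large weights.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
```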

In the next post, I'll explore solving overfitting through these methods.
